mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
adapted for sequence parallel (#163)
This commit is contained in:
@@ -47,16 +47,16 @@ def free_port():
|
||||
continue
|
||||
|
||||
|
||||
def sync_model_param_in_dp(model):
|
||||
def sync_model_param(model, parallel_mode):
|
||||
'''Make sure data parameters are consistent during Data Parallel Mode
|
||||
|
||||
:param model: A pyTorch nn.model on whose parameters you check the consistency
|
||||
'''
|
||||
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
||||
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
|
||||
for param in model.parameters():
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
|
||||
ranks = gpc.get_ranks_in_group(parallel_mode)
|
||||
dist.broadcast(
|
||||
param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
|
||||
param, src=ranks[0], group=gpc.get_group(parallel_mode))
|
||||
|
||||
|
||||
def is_dp_rank_0():
|
||||
@@ -79,6 +79,10 @@ def is_using_pp():
|
||||
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
|
||||
|
||||
|
||||
def is_using_sequence():
|
||||
return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
|
||||
|
||||
|
||||
@contextmanager
|
||||
def conditional_context(context_manager, enable=True):
|
||||
if enable:
|
||||
@@ -240,16 +244,20 @@ def count_zeros_fp32(parameters):
|
||||
num_zeros = grad.numel() - torch.count_nonzero(grad)
|
||||
total_num_zeros = num_zeros + total_num_zeros
|
||||
|
||||
total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
|
||||
|
||||
# Sum across all model-parallel GPUs.
|
||||
ops = []
|
||||
ops.append(dist.all_reduce(total_num_zeros,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.TENSOR),
|
||||
async_op=True))
|
||||
ops.append(dist.all_reduce(total_num_zeros,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PIPELINE),
|
||||
async_op=True))
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
ops.append(dist.all_reduce(total_num_zeros,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PIPELINE),
|
||||
async_op=True))
|
||||
|
||||
for req in ops:
|
||||
req.wait()
|
||||
total_num_zeros = total_num_zeros.item()
|
||||
|
Reference in New Issue
Block a user