adapted for sequence parallel (#163)

This commit is contained in:
Frank Lee
2022-01-20 13:44:51 +08:00
committed by GitHub
parent a2e649da39
commit e2089c5c15
17 changed files with 432 additions and 119 deletions

View File

@@ -47,16 +47,16 @@ def free_port():
continue
def sync_model_param_in_dp(model):
def sync_model_param(model, parallel_mode):
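'''Make sure model parameters are consistent across the ranks of the given parallel mode
:param model: A PyTorch nn.Module whose parameters are broadcast from the first rank of the group
:param parallel_mode: The parallel mode whose process group the parameters are synchronized over
'''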
if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
ranks = gpc.get_ranks_in_group(parallel_mode)
dist.broadcast(
param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
param, src=ranks[0], group=gpc.get_group(parallel_mode))
def is_dp_rank_0():
@@ -79,6 +79,10 @@ def is_using_pp():
return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1
def is_using_sequence():
    '''Tell whether sequence parallelism is active.

    :return: a truthy value only when the ``ParallelMode.SEQUENCE`` group has
        been initialized in ``gpc`` and its world size is greater than 1
    '''
    # Look the mode up once and reuse it for both queries.
    seq_mode = ParallelMode.SEQUENCE
    return gpc.is_initialized(seq_mode) and gpc.get_world_size(seq_mode) > 1
@contextmanager
def conditional_context(context_manager, enable=True):
if enable:
@@ -240,16 +244,20 @@ def count_zeros_fp32(parameters):
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda()
# Sum across all model-parallel GPUs.
ops = []
ops.append(dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True))
ops.append(dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
if gpc.is_initialized(ParallelMode.PIPELINE):
ops.append(dist.all_reduce(total_num_zeros,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE),
async_op=True))
for req in ops:
req.wait()
total_num_zeros = total_num_zeros.item()