[zero] fix gradient clipping in hybrid parallelism (#2521)

* [zero] fix gradient clipping in hybrid parallelism

* [testing] change model name to avoid pytest warning

* [hotfix] fix unit testing
HELSON authored on 2023-01-29 15:09:57 +08:00; committed by GitHub
parent fd8d19a6e7
commit 077a5cdde4
6 changed files with 45 additions and 26 deletions


@@ -6,9 +6,7 @@ import torch.distributed as dist
 from torch._six import inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter
 from colossalai.utils import is_model_parallel_parameter
@@ -225,7 +223,10 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
     for g, p in zip(gradients, params):
         # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-        if is_model_parallel_parameter(p) or mp_rank == 0:
+        tp_param_flag = False
+        if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
+            tp_param_flag = True
+        if tp_param_flag or mp_rank == 0:
             param_norm = g.data.double().norm(2)
             total_norm += param_norm.item()**2
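
The new flag generalizes the old check: besides parameters flagged by is_model_parallel_parameter, any ColoParameter whose is_replicate() returns False holds a distinct tensor-parallel shard, so every rank must contribute its local norm; replicated parameters are counted only on mp_rank 0. A minimal sketch of that counting rule, using the ColossalAI imports from the hunk above; the helper name local_sq_grad_norm is illustrative, not part of the patch:

```python
import torch

from colossalai.tensor import ColoParameter
from colossalai.utils import is_model_parallel_parameter


def local_sq_grad_norm(gradients, params, mp_rank):
    """Sum of squared gradient 2-norms that this rank should count."""
    total = 0.0
    for g, p in zip(gradients, params):
        # Sharded (tensor-parallel) parameters: each TP rank holds a
        # different slice, so every rank contributes its local norm.
        tp_param_flag = is_model_parallel_parameter(p) or (
            isinstance(p, ColoParameter) and not p.is_replicate())
        # Replicated parameters exist identically on all TP ranks; counting
        # them only on mp_rank 0 avoids inflating the norm by the TP size.
        if tp_param_flag or mp_rank == 0:
            total += g.data.double().norm(2).item() ** 2
    return total
```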
@@ -234,7 +235,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
         torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
     if mp_group is not None:
-        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
+        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
     total_norm = total_norm_cuda[0].item()**(1. / norm_type)
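
The per-rank sums are then reduced in two stages, and this hunk fixes the second stage: the model-parallel all-reduce omitted group=, so it ran over the default world group and re-added the data-parallel contributions that stage one had already accumulated. A minimal sketch of the corrected flow, assuming dp_group and mp_group are torch.distributed process groups as in compute_norm; global_grad_norm is an illustrative name:

```python
import torch
import torch.distributed as dist


def global_grad_norm(local_sq_sum, dp_group, mp_group, norm_type=2):
    total_norm_cuda = torch.cuda.FloatTensor([float(local_sq_sum)])
    # Stage 1: sum the partial squared norms across data-parallel shards.
    dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=dp_group)
    # Stage 2: sum across model-parallel ranks. Passing group=mp_group is the
    # fix; without it the reduce runs over the world group and double-counts
    # the data-parallel sums from stage 1.
    if mp_group is not None:
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mp_group)
    return total_norm_cuda[0].item() ** (1.0 / norm_type)
```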