Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 23:18:36 +00:00
[zero] fix gradient clipping in hybrid parallelism (#2521)
* [zero] fix gradient clipping in hybrid parallelism
* [testing] change model name to avoid pytest warning
* [hotfix] fix unit testing
@@ -6,9 +6,7 @@ import torch.distributed as dist
 from torch._six import inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter
 from colossalai.utils import is_model_parallel_parameter
 
 
@@ -225,7 +223,10 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
 
     for g, p in zip(gradients, params):
         # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-        if is_model_parallel_parameter(p) or mp_rank == 0:
+        tp_param_flag = False
+        if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
+            tp_param_flag = True
+        if tp_param_flag or mp_rank == 0:
             param_norm = g.data.double().norm(2)
             total_norm += param_norm.item()**2
 
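The hunk above is the core of the fix: a gradient contributes to the local squared-norm sum only if its parameter is tensor-parallel sharded (each model-parallel rank holds a distinct shard, now including ColoParameters that are not replicated) or, for replicated parameters, only on model-parallel rank 0, so nothing is counted more than once before the cross-group reduction. A minimal, self-contained sketch of that accumulation rule follows; it is not ColossalAI code, and the helper name accumulate_sq_norm and the sharded_flags argument are illustrative stand-ins for is_model_parallel_parameter and ColoParameter.is_replicate().

# Illustrative sketch only (not ColossalAI code): mimic the rule that a replicated
# parameter is counted once (on model-parallel rank 0), while a tensor-parallel
# sharded parameter is counted on every rank because each rank owns its own shard.
import torch

def accumulate_sq_norm(gradients, sharded_flags, mp_rank):
    # sharded_flags[i] stands in for "is_model_parallel_parameter(p) or a
    # non-replicated ColoParameter" for the i-th parameter (assumption).
    total_norm = 0.0
    for g, sharded in zip(gradients, sharded_flags):
        if sharded or mp_rank == 0:
            param_norm = g.data.double().norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm

# A sharded gradient counts on every model-parallel rank; a replicated one only on rank 0.
grads = [torch.ones(4), torch.full((4,), 2.0)]
print(accumulate_sq_norm(grads, sharded_flags=[True, False], mp_rank=1))   # 4.0
print(accumulate_sq_norm(grads, sharded_flags=[True, False], mp_rank=0))   # 20.0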
@@ -234,7 +235,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
     torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
 
     if mp_group is not None:
-        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
+        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
 
     total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
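The second hunk completes the fix by passing group=mp_group to the all-reduce. Without an explicit group, torch.distributed.all_reduce reduces over the default (world) process group, which in hybrid parallelism also spans the data-parallel ranks, so the partial norms would be summed more times than intended; scoping the call to the model-parallel group adds each contribution exactly once. Below is a minimal, single-process sketch of the corrected call; the gloo backend, world_size=1 setup, and the helper name reduce_norm_over_mp_group are assumptions for demonstration only, not part of the patch.

# Minimal runnable sketch (assumption: single-process gloo group, for demonstration).
# The point is the `group=` argument: it restricts the reduction to the
# model-parallel group instead of the default WORLD group.
import os

import torch
import torch.distributed as dist

def reduce_norm_over_mp_group(total_norm_sq, mp_group):
    # Sum the squared-norm contributions across the model-parallel group only.
    dist.all_reduce(total_norm_sq, op=dist.ReduceOp.SUM, group=mp_group)
    return total_norm_sq

if __name__ == "__main__":
    # Single-process setup purely so the call runs; a real hybrid-parallel job
    # launches many ranks and derives dp/mp groups from the parallel topology.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    mp_group = dist.new_group(ranks=[0])
    total_norm_sq = torch.tensor([4.0])
    print(reduce_norm_over_mp_group(total_norm_sq, mp_group))   # tensor([4.])
    dist.destroy_process_group()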