[hotfix] fix sharded optim step and clip_grad_norm (#1226)

This commit is contained in:
ver217
2022-07-08 13:34:48 +08:00
committed by GitHub
parent f071b500b6
commit a45ddf2d5f
2 changed files with 10 additions and 4 deletions

View File

@@ -195,7 +195,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Make sure the grads are in fp32
assert param.grad.dtype == torch.float, \
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
if hasattr(param, 'zero_is_sharded'):
if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded:
has_zero_shared_param = True
params.append(param)
@@ -234,7 +234,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif hasattr(p, 'zero_is_sharded'):
elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded:
zero_sharded_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)