[hotfix] fix norm type error in zero optimizer (#4795)

This commit is contained in:
littsk, committed by GitHub, 2023-09-27 10:35:24 +08:00
parent da15fdb9ca
commit 54b3ad8924


@@ -221,8 +221,8 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
     else:
         total_norm = 0.0
         for g in gradients:
-            param_norm = g.data.double().norm(2)
-            total_norm += param_norm.item() ** 2
+            param_norm = g.data.double().norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
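For context, a minimal, self-contained sketch of the corrected logic (the standalone name compute_total_norm is illustrative, not the actual ColossalAI API). The bug was that each gradient's norm was computed with a hardcoded p of 2 even when the caller passed a different norm_type; the fix keeps the per-gradient norm and the accumulation exponent consistent:

import torch
from torch import Tensor
from typing import List

def compute_total_norm(gradients: List[Tensor], norm_type: float = 2.0) -> float:
    # Accumulate each gradient's p-norm raised to norm_type, then take the
    # norm_type-th root, so the result is a true p-norm for any norm_type.
    total_norm = 0.0
    for g in gradients:
        param_norm = g.data.double().norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    return total_norm ** (1.0 / norm_type)

# With norm_type=2 this matches the old behavior; with e.g. norm_type=4 the
# old code mixed a 4-norm with squaring and produced the wrong total.
grads = [torch.ones(3), torch.full((2,), 2.0)]
print(compute_total_norm(grads, norm_type=2.0))  # sqrt(3 + 8) ≈ 3.3166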