[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)

* Add clip_grad_norm for hybrid_parallel_plugin

* polish code

* add unittests

* Move tp to a higher-level optimizer interface.

* bug fix

* polish code
littsk
2023-10-12 11:32:37 +08:00
committed by GitHub
parent df63564184
commit 83b52c56cd
8 changed files with 1158 additions and 90 deletions
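For orientation before the diff, a minimal usage sketch of the feature (a sketch only: it assumes HybridParallelPlugin accepts a max_norm keyword routed to the new clip_grad_norm logic, and the model/optimizer setup is illustrative):

# Sketch, not verified against this commit: enable gradient clipping through the plugin.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.MSELoss()

plugin = HybridParallelPlugin(tp_size=1, pp_size=1, max_norm=1.0)  # max_norm: assumed kwarg
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

x = torch.randn(4, 16).cuda()
loss = criterion(model(x), torch.randn(4, 16).cuda())
booster.backward(loss, optimizer)
optimizer.step()  # the wrapped optimizer clips the global grad norm to max_norm before updating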


@@ -3,9 +3,7 @@ from typing import Optional
import torch
import torch.distributed as dist
from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
def flatten(input_):
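The hunk cuts off the helper bodies; presumably they are thin wrappers around torch's internal dense-tensor helpers, roughly as sketched below (not the file's verbatim code):

# Rough sketch of the elided helpers; the actual bodies are not shown in this hunk.
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flatten(input_):
    # Pack a list of tensors into one contiguous 1-D buffer.
    return _flatten_dense_tensors(input_)

def unflatten(flat, tensors):
    # Split the flat buffer back into views shaped like the original tensors.
    return _unflatten_dense_tensors(flat, tensors)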
@@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list):
        total_norm += norm**2.0
    return math.sqrt(total_norm)

def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> float:
    """Compute the gradient norm of an iterable of parameters.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    adds functionality to handle model parallel parameters.

    Args:
        gradients (Tensor): The gradients to compute the norm of
        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
        tp_group (ProcessGroup): The process group of Tensor Parallelism
        norm_type (int, optional): Type of the used p-norm. Can be ``'inf'`` for infinity norm. Defaults to 2.

    Returns:
        float: The total norm of the given gradients
    """
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(g.data.abs().max() for g in gradients)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
        # Take max across all GPUs.
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=tp_group)
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0.0
        for g in gradients:
            param_norm = g.data.double().norm(norm_type)
            total_norm += param_norm.item() ** norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

    if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
        total_norm = -1

    return total_norm

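For context, a sketch of how a caller would typically turn this norm into actual clipping (the coefficient computation below is illustrative, not code from this diff):

# Illustrative caller: scale gradients so the global norm does not exceed max_norm.
max_norm = 1.0
total_norm = compute_norm(gradients, dp_group, tp_group, norm_type=2)
if total_norm > 0:  # compute_norm signals inf/NaN with -1, which skips scaling here
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in gradients:
            g.data.mul_(clip_coef)
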
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When