Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] allow passing process group to zero12 (#4153)
* allow passing process group to zero12
* union tp-zero and normal-zero
* polish code
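In short, #4153 drops the params list and the mp_group argument from compute_norm and instead takes explicit ProcessGroup handles for ZeRO data parallelism (dp_group) and tensor parallelism (tp_group). Below is a minimal sketch of how a caller might wire up the new signature; the 2x2 group layout, the model variable, and the import of compute_norm from the patched module are illustrative assumptions, not part of the commit:

import torch
import torch.distributed as dist

# Hypothetical layout: 4 ranks arranged as 2 data-parallel x 2 tensor-parallel.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()

# Every rank must create every group, then keep the one it belongs to.
dp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
tp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
dp_group = dp_groups[rank // 2]
tp_group = tp_groups[rank % 2]

# `model` and the import of compute_norm are assumed to exist in context.
grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = compute_norm(grads, dp_group, tp_group, norm_type=2)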
@@ -3,8 +3,9 @@ from typing import Optional
 import torch
 import torch.distributed as dist
-from torch import inf
+from torch import Tensor, inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup

 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_model_parallel_parameter
@@ -194,25 +195,20 @@ def calculate_global_norm_from_list(norm_list):
     return math.sqrt(total_norm)


-def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
+def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
     """Clips gradient norm of an iterable of parameters.

     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters. Note that
-    the gradients are modified in place.
-
-    Arguments:
-        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-            single Tensor that will have gradients normalized
-        max_norm (float or int): max norm of the gradients
-        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
-
-    Returns:
-        Total norm of the parameters (viewed as a single vector).
-    """
-
-    if mp_group is None:
-        mp_rank = 0
-    else:
-        mp_rank = dist.get_rank(mp_group)
+    added functionality to handle model parallel parameters.
+
+    Args:
+        gradients (Tensor): The gradients to compute norm
+        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
+        tp_group (ProcessGroup): The process group of Tensor Parallelism
+        norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
+
+    Returns:
+        int: The total norm of given gradients
+    """

     norm_type = float(norm_type)
     if norm_type == inf:
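At call sites, the visible effect of the hunk above is that the params argument disappears and mp_group is renamed to tp_group. A hedged before/after, with grads, dp_group, and tp_group standing in for whatever the caller already holds:

# Old signature (before #4153):
# total_norm = compute_norm(grads, params, dp_group, mp_group)
# New signature: gradients only, with both groups passed explicitly.
total_norm = compute_norm(grads, dp_group, tp_group, norm_type=2)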
@@ -221,29 +217,21 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
         dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)

         # Take max across all GPUs.
-        if mp_group is not None:
+        if tp_group is not None:
             dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.0
-        # if dist.get_rank() == 0:
-        #    logger.info(f"Total Norm beginning {total_norm}")
-
-        for g, p in zip(gradients, params):
-            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-            tp_param_flag = False
-            if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
-                tp_param_flag = True
-            if tp_param_flag or mp_rank == 0:
-                param_norm = g.data.double().norm(2)
-                total_norm += param_norm.item()**2
+        for g in gradients:
+            param_norm = g.data.double().norm(2)
+            total_norm += param_norm.item()**2

     # Sum across all model parallel GPUs.
     total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
     torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)

-    if mp_group is not None:
-        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
+    if tp_group is not None:
+        dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)

     total_norm = total_norm_cuda[0].item()**(1. / norm_type)
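For the default 2-norm, the hunk above computes the total as the sum of squared local 2-norms, SUM-reduced over dp_group and then tp_group, with the root taken once at the end; the inf branch takes MAX across the same groups instead. A single-process sketch of the same arithmetic (no distributed calls; all names are illustrative):

import torch

# Stand-ins for the gradient shards the ranks would hold collectively.
shards = [torch.randn(5), torch.randn(7), torch.randn(3)]

# Per-rank accumulation as in the patched loop: sum of squared 2-norms.
total = sum(g.double().norm(2).item() ** 2 for g in shards)
# The dp_group/tp_group all-reduces would perform this same summation
# across ranks; the root is applied once at the end (norm_type = 2).
total_norm = total ** (1.0 / 2)

# This equals the 2-norm of every shard flattened into one vector.
flat = torch.cat([g.flatten() for g in shards]).double()
assert abs(total_norm - flat.norm(2).item()) < 1e-9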