[zero] allow passing process group to zero12 (#4153)

* allow passing process group to zero12

* union tp-zero and normal-zero

* polish code
This commit is contained in:
LuGY
2023-07-04 17:41:28 +08:00
committed by Hongxin Liu
parent 79cf1b5f33
commit c668801d36
4 changed files with 41 additions and 90 deletions

View File

@@ -3,8 +3,9 @@ from typing import Optional
import torch
import torch.distributed as dist
from torch import inf
from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_model_parallel_parameter
@@ -194,25 +195,20 @@ def calculate_global_norm_from_list(norm_list):
return math.sqrt(total_norm)
def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
added functionality to handle model parallel parameters.
if mp_group is None:
mp_rank = 0
else:
mp_rank = dist.get_rank(mp_group)
Args:
gradients (Tensor): The gradients to compute norm
dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
tp_group (ProcessGroup): The process group of Tensor Parallelism
norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
Returns:
int: The total norm of given gradients
"""
norm_type = float(norm_type)
if norm_type == inf:
@@ -221,29 +217,21 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
# Take max across all GPUs.
if mp_group is not None:
if tp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# logger.info(f"Total Norm beginning {total_norm}")
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
tp_param_flag = False
if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
tp_param_flag = True
if tp_param_flag or mp_rank == 0:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
for g in gradients:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
if tp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)