[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)

* Add clip_grad_norm for hibrid_parallel_plugin

* polish code

* add unittests

* Move tp to a higher-level optimizer interface.

* bug fix

* polish code
This commit is contained in:
littsk
2023-10-12 11:32:37 +08:00
committed by GitHub
parent df63564184
commit 83b52c56cd
8 changed files with 1158 additions and 90 deletions

View File

@@ -3,9 +3,7 @@ from typing import Optional
import torch
import torch.distributed as dist
from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
def flatten(input_):
@@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list):
total_norm += norm**2.0
return math.sqrt(total_norm)
def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters.
Args:
gradients (Tensor): The gradients to compute norm
dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
tp_group (ProcessGroup): The process group of Tensor Parallelism
norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
Returns:
int: The total norm of given gradients
"""
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
# Take max across all GPUs.
if tp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
for g in gradients:
param_norm = g.data.double().norm(norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if tp_group is not None:
dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)
if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
total_norm = -1
return total_norm
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When

View File

@@ -21,6 +21,8 @@ class GradientStore(BaseStore):
# for zero2, it's `param_id: [grad_local_rank]`
self._working_index = 0 if partition_grad else self._local_rank
self.grad_to_param_mapping = dict()
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
"""Return list of gradient slices of a specific parameter
@@ -54,6 +56,8 @@ class GradientStore(BaseStore):
else:
self._grads_of_params[group_id][param_id].append(grad)
self.grad_to_param_mapping[id(grad)] = param_id
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
"""Add a gradient slice on an existing slice of the parameter's gradient
Used when no_sync is not activated.
@@ -83,8 +87,37 @@ class GradientStore(BaseStore):
return grad_list
def get_working_grad_by_param_id(self, param_id) -> Tensor:
"""
Return the working gradient for the specified parameter.
Args:
param_id (int): The index of the parameter.
Returns:
Tensor: The the working gradient slices for the specified param_id.
"""
for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id][self._working_index]
raise KeyError(f"Working gradient for param_id {param_id} not found.")
def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
def reset_all_gradients(self):
self._grads_of_params = dict()
def get_param_id_for_grad(self, grad: Tensor) -> int:
"""Return the id of a parameter which the gradient slice belongs to
Args:
grad (Tensor): the gradient slice
Returns:
int: the id of a parameter which the gradient slice belongs to
"""
return self.grad_to_param_mapping[id(grad)]

View File

@@ -2,11 +2,12 @@
import copy
from contextlib import contextmanager
from functools import partial
from typing import Dict, Iterator, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -21,14 +22,7 @@ from colossalai.logging import get_dist_logger
# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
from ._utils import (
calculate_global_norm_from_list,
compute_norm,
flatten,
has_inf_or_nan,
release_param_grad,
sync_tensor,
)
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -80,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -101,8 +94,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
self.tp_pg = tp_process_group
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
@@ -433,7 +424,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# compute norm
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
norm_group = self._compute_grad_norm(gradients=working_grads)
norm_groups.append(norm_group)
self._grad_store.reset_grads_by_group_id(group_id)
@@ -467,6 +458,44 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
gradients (List[Tensor]): The gradients to compute norm
norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
Returns:
float: The total norm of given gradients
"""
if len(gradients) == 0:
return 0.0
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
else:
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
total_norm_exponentiated += grad_norm_exponentiated
# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
#############################
# Mixed Precision Utilities #
#############################