mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin * polish code * add unittests * Move tp to a higher-level optimizer interface. * bug fix * polish code
This commit is contained in:
@@ -3,9 +3,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor, inf
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
def flatten(input_):
|
||||
@@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list):
|
||||
total_norm += norm**2.0
|
||||
return math.sqrt(total_norm)
|
||||
|
||||
|
||||
def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> float:
    """Compute the total norm of ``gradients`` across the given process groups.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Only the norm is
    computed here; no gradient is modified or clipped.

    Args:
        gradients (Tensor): The gradients to compute norm. The value is iterated
            over, each item being treated as one gradient tensor.
        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
        tp_group (ProcessGroup): The process group of Tensor Parallelism. When
            ``None`` the tensor-parallel reduction is skipped.
        norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.

    Returns:
        float: The total norm of given gradients, or ``-1`` when the reduced
        value is infinite or NaN.
    """

    norm_type = float(norm_type)
    if norm_type == inf:
        # Infinity norm: largest absolute entry over all local gradients.
        total_norm = max(g.data.abs().max() for g in gradients)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=dp_group)

        # Take the max across the tensor parallel group as well. The group must
        # be passed explicitly; without it the reduction runs over the default
        # process group instead of ``tp_group`` (the p-norm branch below
        # already restricts its reduction to ``tp_group``).
        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=tp_group)
        total_norm = total_norm_cuda[0].item()
    else:
        # p-norm: accumulate sum(|g|^p) locally, reduce the sums, then take the p-th root.
        total_norm = 0.0
        for g in gradients:
            param_norm = g.data.double().norm(norm_type)
            total_norm += param_norm.item() ** norm_type

        # Sum across the data parallel group, then the tensor parallel group.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=dp_group)

        if tp_group is not None:
            dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM, group=tp_group)

        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)

    # ``x != x`` is only true for NaN; a non-finite result is reported as -1.
    if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
        total_norm = -1

    return total_norm
|
||||
|
||||
|
||||
def sync_tensor(flat_tensor, tensor_list):
|
||||
"""
|
||||
Synchronize the flattened tensor and unflattened tensor list. When
|
||||
|
@@ -21,6 +21,8 @@ class GradientStore(BaseStore):
|
||||
# for zero2, it's `param_id: [grad_local_rank]`
|
||||
self._working_index = 0 if partition_grad else self._local_rank
|
||||
|
||||
self.grad_to_param_mapping = dict()
|
||||
|
||||
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
|
||||
"""Return list of gradient slices of a specific parameter
|
||||
|
||||
@@ -54,6 +56,8 @@ class GradientStore(BaseStore):
|
||||
else:
|
||||
self._grads_of_params[group_id][param_id].append(grad)
|
||||
|
||||
self.grad_to_param_mapping[id(grad)] = param_id
|
||||
|
||||
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
|
||||
"""Add a gradient slice on an existing slice of the parameter's gradient
|
||||
Used when no_sync is not activated.
|
||||
@@ -83,8 +87,37 @@ class GradientStore(BaseStore):
|
||||
|
||||
return grad_list
|
||||
|
||||
def get_working_grad_by_param_id(self, param_id) -> Tensor:
    """
    Return the working gradient for the specified parameter.

    The groups stored in ``self._grads_of_params`` are searched in iteration
    order and the first group containing ``param_id`` supplies the result.

    Args:
        param_id (int): The index of the parameter.

    Returns:
        Tensor: The working gradient slice for the specified param_id, i.e. the
        entry at position ``self._working_index`` of that parameter's gradient list.

    Raises:
        KeyError: If no group holds gradients for ``param_id``.
    """

    for group in self._grads_of_params.values():
        # Test membership on the mapping directly instead of on ``.keys()``.
        if param_id in group:
            return group[param_id][self._working_index]

    raise KeyError(f"Working gradient for param_id {param_id} not found.")
|
||||
|
||||
def reset_grads_by_group_id(self, group_id: int):
    """Forget every gradient stored for one group.

    The group's entry in ``self._grads_of_params`` is rebound to a new empty
    mapping; the previously stored mapping object itself is not mutated.

    Args:
        group_id (int): Key of the group whose gradients are discarded.
    """
    self._grads_of_params[group_id] = {}
|
||||
|
||||
def reset_all_gradients(self):
    """Forget the stored gradients of every group.

    ``self._grads_of_params`` is rebound to a new empty mapping; the mapping
    it referred to before is left untouched.
    """
    self._grads_of_params = {}
|
||||
|
||||
def get_param_id_for_grad(self, grad: Tensor) -> int:
    """Look up the parameter that a gradient slice belongs to.

    The lookup in ``self.grad_to_param_mapping`` is keyed by ``id(grad)``, so
    the very object that was registered has to be passed in.

    Args:
        grad (Tensor): the gradient slice

    Returns:
        int: the id of a parameter which the gradient slice belongs to

    Raises:
        KeyError: If ``id(grad)`` has no entry in the mapping.
    """
    grad_key = id(grad)
    return self.grad_to_param_mapping[grad_key]
|
||||
|
@@ -2,11 +2,12 @@
|
||||
import copy
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, Optional, Tuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, inf
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
|
||||
@@ -21,14 +22,7 @@ from colossalai.logging import get_dist_logger
|
||||
# from colossalai.tensor import ColoParameter, ProcessGroup
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from ._utils import (
|
||||
calculate_global_norm_from_list,
|
||||
compute_norm,
|
||||
flatten,
|
||||
has_inf_or_nan,
|
||||
release_param_grad,
|
||||
sync_tensor,
|
||||
)
|
||||
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
|
||||
from .bookkeeping import BucketStore, GradientStore, ParameterStore
|
||||
|
||||
|
||||
@@ -80,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
|
||||
@@ -101,8 +94,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
self._local_rank = dist.get_rank(group=self.dp_pg)
|
||||
self._world_size = dist.get_world_size(group=self.dp_pg)
|
||||
|
||||
self.tp_pg = tp_process_group
|
||||
|
||||
# working and master params for mixed precision training
|
||||
self._working_param_groups = dict()
|
||||
self._master_param_groups_of_current_rank = dict()
|
||||
@@ -433,7 +424,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
# compute norm
|
||||
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
|
||||
norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
|
||||
norm_group = self._compute_grad_norm(gradients=working_grads)
|
||||
norm_groups.append(norm_group)
|
||||
|
||||
self._grad_store.reset_grads_by_group_id(group_id)
|
||||
@@ -467,6 +458,44 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
||||
|
||||
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||
r"""
|
||||
Compute and return the gradient norm for gradient clipping.
|
||||
|
||||
Args:
|
||||
gradients (List[Tensor]): The gradients to compute norm
|
||||
norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
float: The total norm of given gradients
|
||||
"""
|
||||
|
||||
if len(gradients) == 0:
|
||||
return 0.0
|
||||
|
||||
norm_type = float(norm_type)
|
||||
if norm_type == inf:
|
||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
|
||||
total_norm = total_norm_cuda.item()
|
||||
|
||||
else:
|
||||
total_norm_exponentiated = 0.0
|
||||
for grad in gradients:
|
||||
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
|
||||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
# Sum across all model parallel GPUs.
|
||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||
torch.distributed.all_reduce(
|
||||
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
|
||||
)
|
||||
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
|
||||
|
||||
return total_norm
|
||||
|
||||
#############################
|
||||
# Mixed Precision Utilities #
|
||||
#############################
|
||||
|
Reference in New Issue
Block a user