Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin
* polish code
* add unittests
* Move tp to a higher-level optimizer interface.
* bug fix
* polish code
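For context, the new clipping path is driven entirely by the plugin's max_norm argument, which configure() forwards to whichever optimizer wrapper it builds. A minimal usage sketch follows (not part of this diff; it assumes the plugin's usual tp_size/pp_size/precision arguments, the standard Booster workflow, and a launch via torchrun, all of which may differ slightly between versions):

# run with: torchrun --standalone --nproc_per_node=1 demo.py
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# max_norm > 0 enables the gradient clipping added by this commit.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="fp16", max_norm=1.0)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
loss = model(x).sum()
booster.backward(loss, optimizer)
optimizer.step()  # gradients are clipped to max_norm before the parameter update
optimizer.zero_grad()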
@@ -1,3 +1,4 @@
+import ctypes
 import random
 from contextlib import nullcontext
 from functools import partial
@@ -7,7 +8,8 @@ from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple,
 import numpy as np
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import Tensor, inf
+from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -24,6 +26,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
 from .pp_plugin_base import PipelinePluginBase
@@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
 
 
 class HybridParallelNaiveOptimizer(OptimizerWrapper):
-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+    def __init__(
+        self,
+        optim: Optimizer,
+        model: Module,
+        use_pipeline: bool,
+        param_info: OrderedDict,
+        max_norm: float = 0,
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
+    ):
         self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.max_norm = max_norm
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         super().__init__(optim)
 
+    def step(self, *args, **kwargs):
+        r"""
+        Perform an optimization step.
+
+        Args:
+            *args: Variable-length positional arguments to be passed to the optimizer's step function.
+            **kwargs: Keyword arguments to be passed to the optimizer's step function.
+        """
+
+        if self.max_norm > 0:
+            # Compute the total gradient norm.
+            param_gradient_pairs = [
+                (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
+            ]
+            total_norm = self._compute_grad_norm(param_gradient_pairs)
+
+            # Clip the gradients to prevent exploding gradients.
+            self._clip_grad_norm(total_norm)
+
+        # Perform the optimization step using the underlying optimizer.
+        self.optim.step(*args, **kwargs)
+
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        # gradients used for norm calculation.
+        gradients = [grad for param, grad in param_gradient_pairs]
+
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+            total_norm = total_norm_cuda.item()
+        else:
+            # gradients used for norm calculation.
+            gradients = [grad for param, grad in param_gradient_pairs]
+            # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'.
+            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
+
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
+                if tp_size > 1:
+                    param_for_grad = grad_to_param_mapping[id(grad)]
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            if grad is stage_shared_param.grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # compute the total_norm
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
+    def _clip_grad_norm(self, total_norm: float) -> None:
+        r"""
+        Clips the gradients of the model's parameters to prevent exploding gradients.
+
+        Args:
+            total_norm (float): The computed total gradient norm.
+
+        Returns:
+            None
+        """
+        clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6))
+        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+
+        for group in self.optim.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                p.grad.data.mul_(clip_coef_clamped)
+
     def update_master_params(self, model: Module):
         pass
 
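The aggregation above works on exponentiated norms: each rank sums ||g||^p over its local gradients, the partial sums are all-reduced with SUM across the tp/pp groups, and the p-th root is taken at the end. A gradient whose parameter is replicated rather than sharded across the tp group would be counted tp_size times by that SUM, so its contribution is divided by tp_size first; likewise a parameter shared by several pipeline stages is divided by the number of stages holding it. A small single-process check of that identity (not part of the diff):

import torch

norm_type, tp_size = 2.0, 4

full_grad = torch.randn(1024, dtype=torch.float64)      # gradient of a tensor-parallel (sharded) weight
shards = torch.chunk(full_grad, tp_size)                 # the slice each tp rank holds
replicated_grad = torch.randn(256, dtype=torch.float64)  # gradient of a weight that is replicated, not sharded

# per-rank exponentiated norms, mirroring the loop above
per_rank = [
    shard.norm(norm_type) ** norm_type + (replicated_grad.norm(norm_type) ** norm_type) / tp_size
    for shard in shards
]
total = sum(per_rank) ** (1.0 / norm_type)  # stands in for all_reduce(SUM) followed by the p-th root

expected = torch.cat([full_grad, replicated_grad]).norm(norm_type)
assert torch.allclose(total, expected)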
@@ -192,23 +326,108 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         hysteresis: int = 2,
         max_scale: float = 2**32,
         max_norm: float = 0,
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
     ):
         self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(
             optim,
-            precision,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            max_norm,
+            precision=precision,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            max_norm=max_norm,
         )
 
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we need to calculate the norm of 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+
+        else:
+            # gradients used for norm calculation.
+            gradients = [grad for param, grad in param_gradient_pairs]
+            # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism.
+            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
+
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
+                if tp_size > 1:
+                    param_for_grad = grad_to_param_mapping[id(grad)]
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_working_shared_param = shared_param[self.stage_manager.stage]
+                            stage_master_shared_param = self.working_to_master_map[stage_working_shared_param]
+                            if grad is stage_master_shared_param.grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # compute the total_norm
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
 
 class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
@@ -233,9 +452,15 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
     ):
         self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
@@ -255,10 +480,90 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             partition_grad,
             cpu_offload,
             dp_process_group,
             tp_process_group,
             forced_dtype,
         )
 
+    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            gradients (List[Tensor]): A list of tensors containing gradients.
+            norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.
+
+        Returns:
+            float: The computed gradient norm.
+        """
+
+        # Check if the list of gradients is empty
+        if len(gradients) == 0:
+            return 0.0
+
+        dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we only need to calculate the norm of 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(gradients, norm_type)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
+                if tp_size > 1:
+                    param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                    param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
+                            if grad is working_grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if dp_size > 1:
+                # compute norm in dp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # Compute the 'total_norm' from 'total_norm_exponentiated'
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
 
 class HybridParallelPlugin(PipelinePluginBase):
     """
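One implementation detail in the ZeRO variant above: the gradient store only returns the id() of the parameter that owns a working gradient, so the code recovers the actual parameter object with ctypes before calling is_distributed_tensor on it. A tiny illustration of that CPython-specific trick (not part of the diff; it is only safe while a live reference to the object still exists, which holds here because the optimizer retains its parameters):

import ctypes

import torch

param = torch.nn.Parameter(torch.ones(4))
param_id = id(param)  # what the gradient store hands back for a working gradient
recovered = ctypes.cast(param_id, ctypes.py_object).value
assert recovered is param  # same object, so is_distributed_tensor(recovered) can be queried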
@@ -475,11 +780,19 @@ class HybridParallelPlugin(PipelinePluginBase):
                     param_info=param_info,
                     precision=self.precision,
+                    max_norm=self.max_norm,
+                    pp_process_group=self.pp_group,
+                    tp_process_group=self.tp_group,
                     **self.amp_config,
                 )
             else:
                 optimizer = HybridParallelNaiveOptimizer(
-                    optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+                    optimizer,
+                    model,
+                    use_pipeline=self.enable_pipeline_parallelism,
+                    param_info=param_info,
+                    max_norm=self.max_norm,
+                    pp_process_group=self.pp_group,
+                    tp_process_group=self.tp_group,
                 )
         else:
             assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
@@ -491,6 +804,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     param_info=param_info,
                     dp_process_group=self.dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
+                    clip_grad_norm=self.max_norm,
                     **self.zero_config,
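The plugin forwards the same threshold everywhere: max_norm to the naive and AMP wrappers, and clip_grad_norm=self.max_norm to the ZeRO wrapper. The clipping rule itself, shown in _clip_grad_norm earlier in the diff, scales gradients by min(1, max_norm / total_norm). A short worked example of that coefficient (not part of the diff):

import torch

max_norm = 1.0
for total_norm in (4.0, 0.5):
    clip_coef = torch.tensor(max_norm / (total_norm + 1e-6))
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    print(f"total_norm={total_norm} -> scale gradients by {clip_coef_clamped.item():.4f}")
# total_norm=4.0 -> scale gradients by 0.2500 (clipped)
# total_norm=0.5 -> scale gradients by 1.0000 (left unchanged)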