Mirror of https://github.com/hpcaitech/ColossalAI.git
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin
* polish code
* add unittests
* Move tp to a higher-level optimizer interface.
* bug fix
* polish code
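For orientation, a minimal usage sketch (not part of this commit): gradient clipping is driven by the `max_norm` value the plugin stores as `self.max_norm` and forwards to the optimizer wrappers in the diff below. The surrounding boilerplate and argument values are illustrative, and a real run needs `torchrun` with at least `tp_size * pp_size` processes.

```python
# Hypothetical end-to-end sketch; argument values are illustrative.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# max_norm > 0 enables the clip_grad_norm path added in this PR.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16", max_norm=1.0)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(4, 16).cuda()).sum()
booster.backward(loss, optimizer)
optimizer.step()  # gradients are clipped to max_norm before the update
```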
@@ -1,7 +1,7 @@
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch
-from torch import Tensor
+from torch import Tensor, inf
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer

@@ -68,8 +68,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
             self.mixed_precision = BF16MixedPrecisionMixin()
         else:
             raise ValueError(f"Unsupported precision: {precision}")
-        if max_norm > 0.0:
-            raise NotImplementedError("max_norm is not supported yet.")
         self.max_norm = max_norm
         self.working_to_master_map: Dict[Parameter, Tensor] = {}
         self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -102,32 +100,65 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
         return super().zero_grad(*args, **kwargs)

     def _unscale_and_clip_grads(self, total_norm: float) -> None:
+        """
+        Unscale and clip gradients before performing the optimization step.
+
+        Args:
+            total_norm (float): The computed total gradient norm.
+
+        Returns:
+            None
+        """
         div_scale = 1.0
+
+        # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
         if self.mixed_precision is not None:
             div_scale = self.mixed_precision.get_grad_div_scale()

         if self.max_norm > 0.0:
-            # norm is in fact norm*scale
+            # Calculate the scaling factor for gradient clipping
+            # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
+
+            # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
             if clip > 1:
                 div_scale = clip * div_scale

+        # Apply the scaling factor to gradients
         for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is None:
                     continue
                 p.grad.data.mul_(1.0 / div_scale)

-    def _compute_grad_norm(self) -> float:
-        if self.max_norm <= 0.0:
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+
+        if len(param_gradient_pairs) == 0:
             return 0.0
-        grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
-        if len(grads) == 0:
-            return 0.0
-        device = grads[0].device
-        # TODO(ver217): support tp
-        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
-        return total_norm.item()
+
+        # gradients used for norm calculation.
+        gradients = [grad for param, grad in param_gradient_pairs]
+
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
+            total_norm = total_norm_exponentiated ** (1.0 / norm_type)
+
+        return total_norm

     def step(self, *args, **kwargs):
         if self.mixed_precision.should_skip_step():
@@ -142,8 +173,22 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if working_param.grad is not None:
                     p.grad = working_param.grad.data.float()
                     working_param.grad = None
-        total_norm = self._compute_grad_norm()
+
+        # gradient unscale and clip.
+        if self.max_norm <= 0:
+            # no need to compute gradient norm.
+            total_norm = 0.0
+        else:
+            # compute the total norm.
+            param_gradient_pairs = [
+                (self.master_to_working_map[p], p.grad)
+                for group in self.param_groups
+                for p in group["params"]
+                if p.grad is not None
+            ]
+            total_norm = self._compute_grad_norm(param_gradient_pairs)
         self._unscale_and_clip_grads(total_norm)

         self.optim.step(*args, **kwargs)
         # update working params
         for group in self.optim.param_groups:
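A standalone sketch (illustration only, not library code) of the unscale-and-clip arithmetic used in `_unscale_and_clip_grads` above: a single division by `div_scale * clip` removes the loss scale and caps the gradient norm at `max_norm` in one pass over the gradients.

```python
# Illustration with assumed values on plain tensors.
import torch

max_norm, loss_scale = 1.0, 1024.0
grad = torch.randn(1000)                  # the "true" gradient
scaled_grad = grad * loss_scale           # what the optimizer actually holds under fp16 loss scaling

div_scale = loss_scale
total_norm = scaled_grad.norm(2).item()   # norm of the *scaled* gradient, as passed to the method
clip = ((total_norm / div_scale) + 1e-6) / max_norm
if clip > 1:
    div_scale = clip * div_scale          # fold the clip factor into the divisor

unscaled_clipped = scaled_grad / div_scale
print(unscaled_clipped.norm(2))           # roughly max_norm
print(torch.allclose(unscaled_clipped, grad * max_norm / grad.norm(2), atol=1e-4))
```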
@@ -1,3 +1,4 @@
+import ctypes
 import random
 from contextlib import nullcontext
 from functools import partial
@@ -7,7 +8,8 @@ from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple,
 import numpy as np
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import Tensor, inf
+from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -24,6 +26,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer

 from .pp_plugin_base import PipelinePluginBase
@@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):


 class HybridParallelNaiveOptimizer(OptimizerWrapper):
-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+    def __init__(
+        self,
+        optim: Optimizer,
+        model: Module,
+        use_pipeline: bool,
+        param_info: OrderedDict,
+        max_norm: float = 0,
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
+    ):
         self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.max_norm = max_norm
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         super().__init__(optim)

+    def step(self, *args, **kwargs):
+        r"""
+        Perform an optimization step.
+
+        Args:
+            *args: Variable-length positional arguments to be passed to the optimizer's step function.
+            **kwargs: Keyword arguments to be passed to the optimizer's step function.
+        """
+
+        if self.max_norm > 0:
+            # Compute the total gradient norm.
+            param_gradient_pairs = [
+                (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
+            ]
+            total_norm = self._compute_grad_norm(param_gradient_pairs)
+
+            # Clip the gradients to prevent exploding gradients.
+            self._clip_grad_norm(total_norm)
+
+        # Perform the optimization step using the underlying optimizer.
+        self.optim.step(*args, **kwargs)
+
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        # gradients used for norm calculation.
+        gradients = [grad for param, grad in param_gradient_pairs]
+
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+            total_norm = total_norm_cuda.item()
+        else:
+            # gradients used for norm calculation.
+            gradients = [grad for param, grad in param_gradient_pairs]
+            # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'.
+            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
+
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
+                if tp_size > 1:
+                    param_for_grad = grad_to_param_mapping[id(grad)]
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            if grad is stage_shared_param.grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # compute the total_norm
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
+    def _clip_grad_norm(self, total_norm: float) -> None:
+        r"""
+        Clips the gradients of the model's parameters to prevent exploding gradients.
+
+        Args:
+            total_norm (float): The computed total gradient norm.
+
+        Returns:
+            None
+        """
+        clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6))
+        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+
+        for group in self.optim.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                p.grad.data.mul_(clip_coef_clamped)
+
     def update_master_params(self, model: Module):
         pass
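A toy check (plain Python, no distributed runtime) of the `tp_size` division applied in `_compute_grad_norm` above: a parameter that is not a distributed tensor is replicated on every tensor-parallel rank, so every rank would add the same `||g||^p` to the SUM all-reduce; pre-dividing by `tp_size` keeps the reduced value equal to counting that parameter once.

```python
# Assumed values; the SUM all-reduce over the tp group is modelled by a plain sum.
tp_size = 4
grad_norm_p = 9.0   # ||g||^2 of a replicated (non-sharded) parameter on one rank

reduced = sum(grad_norm_p / tp_size for _ in range(tp_size))  # what the all-reduce would yield
assert reduced == grad_norm_p                                 # parameter is effectively counted once
```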
@@ -192,23 +326,108 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         hysteresis: int = 2,
         max_scale: float = 2**32,
         max_norm: float = 0,
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
     ):
         self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(
             optim,
-            precision,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            max_norm,
+            precision=precision,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            max_norm=max_norm,
         )

+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we need to calculate the norm of 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+
+        else:
+            # gradients used for norm calculation.
+            gradients = [grad for param, grad in param_gradient_pairs]
+            # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism.
+            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
+
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
+                if tp_size > 1:
+                    param_for_grad = grad_to_param_mapping[id(grad)]
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_working_shared_param = shared_param[self.stage_manager.stage]
+                            stage_master_shared_param = self.working_to_master_map[stage_working_shared_param]
+                            if grad is stage_master_shared_param.grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # compute the total_norm
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm


 class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
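A toy illustration (plain Python) of the infinity-norm path above: the parent class already reduces over the data-parallel group, and chaining MAX reductions over the tp group and then the pp group is equivalent to taking a single max over all ranks.

```python
# Made-up per-rank infinity norms, indexed as [pp stage][tp rank].
per_rank_inf_norms = [[0.3, 1.7], [0.9, 0.4]]

tp_reduced = [max(stage) for stage in per_rank_inf_norms]   # MAX all-reduce within each tp group
pp_reduced = max(tp_reduced)                                # MAX all-reduce across pp stages
assert pp_reduced == max(v for stage in per_rank_inf_norms for v in stage)
```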
@@ -233,9 +452,15 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
     ):
         self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
@@ -255,10 +480,90 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             partition_grad,
             cpu_offload,
             dp_process_group,
-            tp_process_group,
             forced_dtype,
         )

+    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            gradients (List[Tensor]): A list of tensors containing gradients.
+            norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.
+
+        Returns:
+            float: The computed gradient norm.
+        """
+
+        # Check if the list of gradients is empty
+        if len(gradients) == 0:
+            return 0.0
+
+        dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we only need to calculate the norm of 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(gradients, norm_type)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
+                if tp_size > 1:
+                    param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                    param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
+                            if grad is working_grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if dp_size > 1:
+                # compute norm in dp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # Compute the 'total_norm' from 'total_norm_exponentiated'
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm


 class HybridParallelPlugin(PipelinePluginBase):
     """
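An illustration (not library code) of the `ctypes` round-trip used above: the gradient store records only `id(param)`, and `ctypes.cast(param_id, ctypes.py_object).value` recovers the parameter object from that id. This is safe only while the object is still alive, which holds here because the optimizer and its stores keep references to the parameters.

```python
import ctypes

import torch

param = torch.nn.Parameter(torch.zeros(2))
param_id = id(param)                           # what the gradient store would record

recovered = ctypes.cast(param_id, ctypes.py_object).value
assert recovered is param                      # same object, recovered from its id
```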
@@ -475,11 +780,19 @@ class HybridParallelPlugin(PipelinePluginBase):
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
                         **self.amp_config,
                     )
                 else:
                     optimizer = HybridParallelNaiveOptimizer(
-                        optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
                     )
             else:
                 assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
@@ -491,6 +804,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     param_info=param_info,
                     dp_process_group=self.dp_group,
                     tp_process_group=self.tp_group,
+                    pp_process_group=self.pp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
                     **self.zero_config,
@@ -3,9 +3,7 @@ from typing import Optional

 import torch
 import torch.distributed as dist
-from torch import Tensor, inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.distributed import ProcessGroup


 def flatten(input_):
@@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list):
         total_norm += norm**2.0
     return math.sqrt(total_norm)


-def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
-    """Clips gradient norm of an iterable of parameters.
-    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters.
-
-    Args:
-        gradients (Tensor): The gradients to compute norm
-        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
-        tp_group (ProcessGroup): The process group of Tensor Parallelism
-        norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
-
-    Returns:
-        int: The total norm of given gradients
-    """
-
-    norm_type = float(norm_type)
-    if norm_type == inf:
-        total_norm = max(g.data.abs().max() for g in gradients)
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
-
-        # Take max across all GPUs.
-        if tp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
-        total_norm = total_norm_cuda[0].item()
-    else:
-        total_norm = 0.0
-        for g in gradients:
-            param_norm = g.data.double().norm(norm_type)
-            total_norm += param_norm.item() ** norm_type
-
-        # Sum across all model parallel GPUs.
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
-
-        if tp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
-
-        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)
-
-    if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
-        total_norm = -1
-
-    return total_norm
-
-
 def sync_tensor(flat_tensor, tensor_list):
     """
     Synchronize the flattened tensor and unflattened tensor list. When
@@ -21,6 +21,8 @@ class GradientStore(BaseStore):
         # for zero2, it's `param_id: [grad_local_rank]`
         self._working_index = 0 if partition_grad else self._local_rank

+        self.grad_to_param_mapping = dict()
+
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         """Return list of gradient slices of a specific parameter

@@ -54,6 +56,8 @@ class GradientStore(BaseStore):
         else:
             self._grads_of_params[group_id][param_id].append(grad)

+        self.grad_to_param_mapping[id(grad)] = param_id
+
     def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
         """Add a gradient slice on an existing slice of the parameter's gradient
         Used when no_sync is not activated.
@@ -83,8 +87,37 @@ class GradientStore(BaseStore):

         return grad_list

+    def get_working_grad_by_param_id(self, param_id) -> Tensor:
+        """
+        Return the working gradient for the specified parameter.
+
+        Args:
+            param_id (int): The index of the parameter.
+
+        Returns:
+            Tensor: The working gradient slices for the specified param_id.
+        """
+
+        for group in self._grads_of_params.values():
+            if param_id in group.keys():
+                return group[param_id][self._working_index]
+
+        raise KeyError(f"Working gradient for param_id {param_id} not found.")
+
     def reset_grads_by_group_id(self, group_id: int):
         self._grads_of_params[group_id] = dict()

     def reset_all_gradients(self):
         self._grads_of_params = dict()
+
+    def get_param_id_for_grad(self, grad: Tensor) -> int:
+        """Return the id of a parameter which the gradient slice belongs to
+
+        Args:
+            grad (Tensor): the gradient slice
+
+        Returns:
+            int: the id of a parameter which the gradient slice belongs to
+        """
+
+        return self.grad_to_param_mapping[id(grad)]
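A minimal mock (not the real `GradientStore`) of the bookkeeping added above: keying a dict by `id(grad)` gives a reverse lookup from a gradient slice back to its parameter id, complementing the existing `param_id -> gradient slices` map.

```python
import torch

grads_of_params = {0: {}}    # group_id -> {param_id -> [grad slices]}, as in the store
grad_to_param_mapping = {}   # id(grad) -> param_id, the mapping added by this PR

param_id, grad = 42, torch.randn(8)
grads_of_params[0].setdefault(param_id, []).append(grad)
grad_to_param_mapping[id(grad)] = param_id

assert grad_to_param_mapping[id(grad)] == param_id
assert grads_of_params[0][grad_to_param_mapping[id(grad)]][0] is grad
```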
@@ -2,11 +2,12 @@
 import copy
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Iterator, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor, inf
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

@@ -21,14 +22,7 @@ from colossalai.logging import get_dist_logger
 # from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device

-from ._utils import (
-    calculate_global_norm_from_list,
-    compute_norm,
-    flatten,
-    has_inf_or_nan,
-    release_param_grad,
-    sync_tensor,
-)
+from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore


@@ -80,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
-        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         forced_dtype: Optional[torch.dtype] = None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -101,8 +94,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)
-
-        self.tp_pg = tp_process_group

         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
@@ -433,7 +424,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

             # compute norm
             working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
-            norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
+            norm_group = self._compute_grad_norm(gradients=working_grads)
             norm_groups.append(norm_group)

             self._grad_store.reset_grads_by_group_id(group_id)
@@ -467,6 +458,44 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

+    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            gradients (List[Tensor]): The gradients to compute norm
+            norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
+
+        Returns:
+            float: The total norm of given gradients
+        """
+
+        if len(gradients) == 0:
+            return 0.0
+
+        norm_type = float(norm_type)
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
+            total_norm = total_norm_cuda.item()
+
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            # Sum across all model parallel GPUs.
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            torch.distributed.all_reduce(
+                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+            )
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
     #############################
     # Mixed Precision Utilities #
     #############################
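A toy check (illustration only) of the reduction used in `_compute_grad_norm` above: summing `||g_i||^p` over gradient shards and taking the p-th root equals the p-norm of the concatenated gradient, which is what the SUM all-reduce over the dp group computes across ranks.

```python
import torch

norm_type = 2.0
shards = [torch.randn(5), torch.randn(7), torch.randn(3)]   # stand-ins for per-rank gradient shards

summed = sum(g.double().norm(norm_type) ** norm_type for g in shards)
total_norm = summed ** (1.0 / norm_type)

assert torch.isclose(total_norm, torch.cat(shards).double().norm(norm_type))
```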