[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)

* Add clip_grad_norm for hybrid_parallel_plugin

* polish code

* add unittests

* Move tp to a higher-level optimizer interface.

* bug fix

* polish code
Author: littsk
Date: 2023-10-12 11:32:37 +08:00
Committed by: GitHub
Parent: df63564184
Commit: 83b52c56cd
8 changed files with 1158 additions and 90 deletions


@@ -1,3 +1,4 @@
import ctypes
import random
from contextlib import nullcontext
from functools import partial
@@ -7,7 +8,8 @@ from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple,
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -24,6 +26,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
def __init__(
self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
max_norm: float = 0,
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.max_norm = max_norm
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
super().__init__(optim)
def step(self, *args, **kwargs):
r"""
Perform an optimization step.
Args:
*args: Variable-length positional arguments to be passed to the optimizer's step function.
**kwargs: Keyword arguments to be passed to the optimizer's step function.
"""
if self.max_norm > 0:
# Compute the total gradient norm.
param_gradient_pairs = [
(p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
# Clip the gradients to prevent exploding gradients.
self._clip_grad_norm(total_norm)
# Perform the optimization step using the underlying optimizer.
self.optim.step(*args, **kwargs)
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
Returns:
float: The total norm of the given gradients.
"""
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
# grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'.
grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
# If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
# it indicates that the parameter is not distributed across devices of the 'tp_group'.
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
if tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid counting it more than once, we divide this gradient's exponentiated norm
# by the number of pipeline stages that share the parameter.
if pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
if grad is stage_shared_param.grad:
grad_norm_exponentiated /= len(shared_param)
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
# compute the total_norm
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
def _clip_grad_norm(self, total_norm: float) -> None:
r"""
Clips the gradients of the model's parameters to prevent exploding gradients.
Args:
total_norm (float): The computed total gradient norm.
Returns:
None
"""
clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6))
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for group in self.optim.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(clip_coef_clamped)
def update_master_params(self, model: Module):
pass
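
A brief aside on the clipping step itself: _clip_grad_norm applies the standard global-norm rule, scaling every gradient by min(1, max_norm / (total_norm + 1e-6)), so gradients are only ever scaled down, never up. Below is a minimal single-process sketch of the same rule; the helper name and the toy layer are illustrative, not part of the plugin.

import torch

def clip_gradients_(parameters, total_norm: float, max_norm: float, eps: float = 1e-6) -> None:
    # A factor > 1 means the norm is already within budget; clamp so we never scale up.
    clip_coef = min(max_norm / (total_norm + eps), 1.0)
    for p in parameters:
        if p.grad is not None:
            p.grad.data.mul_(clip_coef)  # in-place rescale, mirroring the optimizer above

# Toy usage: compute the global L2 norm of a linear layer's gradients, then clip to 1.0.
layer = torch.nn.Linear(4, 4)
layer(torch.randn(8, 4)).sum().backward()
total_norm = torch.norm(torch.stack([p.grad.norm(2) for p in layer.parameters()]), 2).item()
clip_gradients_(layer.parameters(), total_norm, max_norm=1.0)
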
@@ -192,23 +326,108 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(
optim,
precision,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
max_norm,
precision=precision,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
max_norm=max_norm,
)
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
Returns:
float: The total norm of the given gradients.
"""
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
# The parent class calculates the norm of 'dp' gradients,
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
# grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism.
grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
# If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
# it indicates that the parameter is not distributed across devices of the 'tp_group'.
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
if tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid counting it more than once, we divide this gradient's exponentiated norm
# by the number of pipeline stages that share the parameter.
if pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_working_shared_param = shared_param[self.stage_manager.stage]
stage_master_shared_param = self.working_to_master_map[stage_working_shared_param]
if grad is stage_master_shared_param.grad:
grad_norm_exponentiated /= len(shared_param)
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
# compute the total_norm
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
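
Both _compute_grad_norm implementations above follow the same aggregation rule: each rank sums ||g||^p for its local gradients, all-reduces the partial sums over the tp and pp groups, and takes the p-th root at the end. A gradient whose parameter is not sharded is identical on every tp rank, so its ||g||^p is pre-divided by tp_size to keep it from being counted tp_size times. Here is a toy single-process check of that arithmetic; no process groups are involved and all names are illustrative.

import torch

tp_size, norm_type = 4, 2.0
replicated_grad = torch.randn(10)          # identical on every (simulated) tp rank
sharded_grads = torch.randn(tp_size, 10)   # each rank holds a distinct shard

# What an all_reduce(SUM) over the tp group would produce:
reduced = sum(
    replicated_grad.norm(norm_type) ** norm_type / tp_size  # replicated: counted once overall
    + sharded_grads[rank].norm(norm_type) ** norm_type      # sharded: each shard counted once
    for rank in range(tp_size)
)
total_norm = reduced ** (1.0 / norm_type)

# Reference: the norm over the full, unsharded set of gradients.
expected = torch.cat([replicated_grad, sharded_grads.flatten()]).norm(norm_type)
assert torch.allclose(total_norm, expected)
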
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
@@ -233,9 +452,15 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
@@ -255,10 +480,90 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
partition_grad,
cpu_offload,
dp_process_group,
tp_process_group,
forced_dtype,
)
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
gradients (List[Tensor]): A list of tensors containing gradients.
norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.
Returns:
float: The computed gradient norm.
"""
# Check if the list of gradients is empty
if len(gradients) == 0:
return 0.0
dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
# The parent class calculates the norm of 'dp' gradients,
# so we only need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
# If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
# it indicates that the parameter is not distributed across devices of the 'tp_group'.
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
if tp_size > 1:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid counting it more than once, we divide this gradient's exponentiated norm
# by the number of pipeline stages that share the parameter.
if pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
if grad is working_grad:
grad_norm_exponentiated /= len(shared_param)
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
# Compute the 'total_norm' from 'total_norm_exponentiated'
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
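
Unlike the other two optimizers, the Zero variant cannot map a gradient back to its parameter directly: its grad store records only the parameter's id(), and ctypes.cast(param_id_for_grad, ctypes.py_object).value turns that id back into the live object. A minimal demonstration of this CPython-specific trick; it is only safe while the original object is still referenced elsewhere.

import ctypes

import torch

param = torch.nn.Parameter(torch.randn(3))
param_id = id(param)  # what a grad store keyed by parameter id would record
recovered = ctypes.cast(param_id, ctypes.py_object).value
assert recovered is param  # the same object is returned, not a copy
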
class HybridParallelPlugin(PipelinePluginBase):
"""
@@ -475,11 +780,19 @@ class HybridParallelPlugin(PipelinePluginBase):
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
**self.amp_config,
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
@@ -491,6 +804,7 @@ class HybridParallelPlugin(PipelinePluginBase):
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
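
For context, this is roughly how the new clipping path gets exercised from user code through the usual Booster workflow. Only the max_norm argument comes directly from this diff; the model, data, and tp/pp sizes below are placeholders, and the snippet assumes a multi-GPU launch via torchrun.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# max_norm > 0 enables the gradient clipping added in this commit.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp32", max_norm=1.0)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(32, 32)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)

x, y = torch.randn(8, 32).cuda(), torch.zeros(8, 32).cuda()
loss = criterion(model(x), y)
booster.backward(loss, optimizer)
optimizer.step()       # gradients are clipped to max_norm=1.0 before the update
optimizer.zero_grad()
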