[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish layer norm handling code

* fix moe module

* polish class initialization code

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Author: littsk
Date: 2023-11-03 13:32:43 +08:00
Committed by: GitHub
Parent commit: d99b2c961a
Commit: 1a3315e336
30 changed files with 1120 additions and 552 deletions


@@ -1,6 +1,6 @@
import ctypes
import random
from contextlib import nullcontext
from contextlib import contextmanager
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper):
precision: str,
shard_config: ShardConfig,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
self.dp_group = dp_group
self.tp_group = tp_group
self.use_ddp = use_ddp
self.require_grad_sync = True
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper):
dist.all_reduce(param.grad, group=group)
dist.barrier()
def no_sync(self) -> Iterator[None]:
# no sync grads across data parallel
return nullcontext()
@contextmanager
def no_sync(self):
r"""
A context manager that disables automatic gradient synchronization (all-reduce) so that gradients can be
synchronized manually while 'no_sync' is active. If no manual synchronization is performed, it occurs in
the first forward-backward pass after exiting the context.
"""
def sync_grads(self):
# sync grad across data parallel
# Store the current value of 'require_grad_sync' to restore it later.
old_require_grad_sync = self.require_grad_sync
# Disable automatic gradient synchronization.
self.require_grad_sync = False
try:
if self.use_ddp:
# If DDP is in use, disable its gradient synchronization as well.
with self.module.no_sync():
yield
else:
yield
finally:
# Restore the original value of 'require_grad_sync'.
self.require_grad_sync = old_require_grad_sync
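
For context, a minimal usage sketch of this context manager for gradient accumulation; the helper name and the surrounding training objects are hypothetical:

```python
def accumulate_then_sync(model, optimizer, criterion, batches, labels):
    """Hypothetical helper: accumulate gradients locally, then sync once."""
    with model.no_sync():  # automatic grad all-reduce disabled in here
        for x, y in zip(batches[:-1], labels[:-1]):
            optimizer.backward(criterion(model(x), y))
    # Outside the context, require_grad_sync is restored; the final backward
    # runs with sync enabled, then data-parallel grads are reduced once.
    optimizer.backward(criterion(model(batches[-1]), labels[-1]))
    model.sync_dp_grads()
    optimizer.step()
    optimizer.zero_grad()
```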
def sync_dp_grads(self):
r"""
Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
This function performs an all-reduce operation to combine gradients from different devices in the DP group.
Args:
None
Returns:
None
"""
# Check if the DP group size is 1, meaning no synchronization is needed.
if self.dp_group.size() == 1:
return
# Iterate through the model's parameters and perform gradient synchronization.
for p in self.module.parameters():
if p.grad is not None:
# Perform all-reduce to combine gradients from different devices.
dist.all_reduce(p.grad, group=self.dp_group)
# Normalize the gradient by dividing it by the DP group size.
p.grad.div_(self.dp_group.size())
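
The sum-then-divide pair above is plain gradient averaging; a toy check on local tensors (no process groups involved):

```python
import torch

# Two ranks' local gradients for the same parameter.
rank0_grad = torch.tensor([1.0, 2.0])
rank1_grad = torch.tensor([3.0, 4.0])

summed = rank0_grad + rank1_grad  # what all_reduce(SUM) leaves on every rank
averaged = summed / 2             # the div_(dp_group.size()) step
assert torch.allclose(averaged, torch.tensor([2.0, 3.0]))
```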
def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
r"""
Synchronize gradients that are partially derived within sequence parallelism
if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
from the module.
Args:
grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
provided, gradients will be extracted from the model.
Returns:
None
"""
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
else:
# Synchronize gradients from the model across the tensor parallelism group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
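
The reason layer-norm gradients need this: under sequence parallelism each TP rank processes only a slice of the sequence, so a layer-norm weight's gradient on one rank is a partial sum over that slice. A sum all-reduce, with no averaging, completes it. An illustrative sketch of that completion step, not the actual SeqParallelUtils implementation:

```python
import torch.distributed as dist

def allreduce_partial_grads(grads, tp_group):
    # Each rank holds the contribution of its sequence shard; summing
    # across the TP group yields the gradient for the full sequence.
    for g in grads:
        if g is not None:
            dist.all_reduce(g, op=dist.ReduceOp.SUM, group=tp_group)
```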
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
@@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(
self,
optim: Optimizer,
model: Module,
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
max_norm: float = 0,
@@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
self.model = model
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.max_norm = max_norm
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
super().__init__(optim)
def backward(self, loss: Tensor, *args, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
This method performs backward pass for gradient computation. If sequence parallelism is enabled
and gradient synchronization is required, it will synchronize gradients that are partially derived
within sequence parallelism across tp parallelism groups.
Args:
loss (Tensor): The loss tensor to compute gradients with respect to.
*args: Additional positional arguments to be passed to the superclass backward method.
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
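
In a plain (non-pipeline) training step this means the sequence-parallel sync happens inside the optimizer's backward, so user code stays unchanged; a hypothetical sketch:

```python
def train_step(model, optimizer, criterion, batch, labels):
    """Hypothetical step with the wrapped model and optimizer."""
    optimizer.zero_grad()
    loss = criterion(model(batch), labels)
    optimizer.backward(loss)  # autograd + SP grad sync (require_grad_sync is True)
    optimizer.step()
    return loss
```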
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across tp parallelism groups.
Args:
tensor (Tensor): The input tensor for which gradients are computed.
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def step(self, *args, **kwargs):
r"""
Perform an optimization step.
@@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
# gradients used for norm calculation.
@@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
@@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
if self.tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
grad_norm_exponentiated /= self.tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
if self.pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
@@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
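
The combination rule used throughout this hunk is the standard shard decomposition of a p-norm: exponentiate each shard's local norm, sum the exponentiated values across groups, then take the 1/p root at the end. A toy verification on plain tensors:

```python
import torch

norm_type = 2.0
shard_a = torch.tensor([3.0, 0.0])  # gradient elements held by one rank
shard_b = torch.tensor([0.0, 4.0])  # gradient elements held by another rank

# Sum of exponentiated per-shard norms (what the all_reduce(SUM) accumulates).
exponentiated = shard_a.norm(norm_type) ** norm_type + shard_b.norm(norm_type) ** norm_type
total_norm = exponentiated ** (1.0 / norm_type)

assert torch.isclose(total_norm, torch.cat([shard_a, shard_b]).norm(norm_type))
```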
@@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
self,
optim: Optimizer,
model: Module,
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = "fp16",
@@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
self.model = model
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(
@@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
max_norm=max_norm,
)
def backward(self, loss: Tensor, *args, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
This method performs backward pass for gradient computation. If sequence parallelism is enabled
and gradient synchronization is required, it will synchronize gradients that are partially derived
within sequence parallelism across tp parallelism groups.
Args:
loss (Tensor): The loss tensor to compute gradients with respect to.
*args: Additional positional arguments to be passed to the superclass backward method.
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across tp parallelism groups.
Args:
tensor (Tensor): The input tensor for which gradients are computed.
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
@@ -374,9 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
@@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
if self.tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
grad_norm_exponentiated /= self.tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
if self.pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_working_shared_param = shared_param[self.stage_manager.stage]
@@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
@@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
):
self.model = model
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
forced_dtype=forced_dtype,
)
def sync_dp_grads(self):
r"""
Synchronize gradients in the data parallelism dimension.
This method wraps the existing `_sync_grad` method to synchronize gradients explicitly in the
data parallelism dimension. The explicit name is needed now that additional parallel dimensions,
tp (tensor parallelism) and pp (pipeline parallelism), have been introduced, and it keeps the
code organized and readable.
Args:
None
Returns:
None
"""
# Call the superclass `_sync_grad` method to synchronize gradients.
super()._sync_grad()
def _sync_sp_grads(self):
r"""
Synchronize gradients that are partially derived within sequence parallelism.
This method is responsible for synchronizing partially derived gradients across tp parallelism groups.
It identifies which gradients are partially derived and, if synchronization is required
and such gradients are found, performs the synchronization.
Args:
None
Returns:
None
"""
def _get_all_working_grads() -> List[Tensor]:
"""Retrieve all working gradients from different parameter groups."""
all_working_grads = []
for group_id in range(self.num_param_groups):
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
all_working_grads.extend(working_grads)
return all_working_grads
def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
"""Identify gradients to be synchronized in the sequence parallelism."""
grads_to_sync = []
for grad in all_working_grads:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
grads_to_sync.append(grad)
if len(grads_to_sync) > 0:
return grads_to_sync
else:
return None
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
else:
return
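
For illustration, the partial-derived check above could be backed by a simple attribute tag set when a policy replaces a layer norm; this sketch is hypothetical and not the actual SeqParallelUtils API:

```python
def mark_sp_partial_derived(param):
    # Hypothetical tag applied to layer-norm-style parameters whose
    # gradients are computed from only a local sequence shard.
    param._sp_partial_derived = True

def is_sp_partial_derived(param):
    return getattr(param, "_sp_partial_derived", False)
```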
def backward(self, loss, retain_graph=False):
"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
This method performs the backward pass for gradient computation based on a given loss tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across TP parallelism groups.
Args:
loss: The loss tensor to compute gradients with respect to.
retain_graph (bool): Whether to retain the computation graph.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def backward_by_grad(self, tensor, grad):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across TP parallelism groups.
Args:
tensor: The input tensor for which gradients are computed.
grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
Returns:
None
"""
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase):
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
tp_group=self.tp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
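
A hedged configuration sketch showing how this construction is reached from user code; the constructor arguments are assumptions based on this plugin's configuration surface, not a definitive recipe:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# With sequence parallelism enabled, the HybridParallelModule built above
# will all-reduce layer-norm gradients across the TP group after backward.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,  # assumed flag; requires tp_size > 1
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```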
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
@@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase):
return_outputs: bool = False,
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed
# Create a context for gradient synchronization based on the optimizer type.
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
# This avoids redundant gradient reduction in pipeline parallelism (gradients accumulated over
# multiple micro-batches should be reduced only once), so automatic synchronization is disabled
# and reduction is performed manually instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()
# If the optimizer is a HybridParallelZeroOptimizer, synchronize data-parallel gradients through
# the optimizer; otherwise, synchronize them through the model. ZeRO sharding and plain data
# parallelism reduce gradients differently, so the two paths are handled separately.
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
optimizer.sync_dp_grads()
else:
model.sync_grads()
model.sync_dp_grads()
return outputs
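
A usage sketch of this pipeline path; the names are hypothetical and the call goes through the Booster, which forwards to this method:

```python
def run_pipeline_step(booster, model, optimizer, criterion, data_iter):
    """Hypothetical pipeline step; grad sync happens inside execute_pipeline."""
    outputs = booster.execute_pipeline(
        data_iter, model, criterion, optimizer, return_loss=True
    )
    optimizer.step()
    optimizer.zero_grad()
    return outputs
```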
def prepare_dataloader(