[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
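Why the all-reduce is needed: under ShardFormer's sequence parallelism, each tensor-parallel rank runs layer norm over only its own slice of the sequence, while the layer-norm weight and bias are replicated on every rank. Each rank's backward pass therefore yields only a partial gradient for those parameters, and the partials must be summed across the tensor-parallel group before the optimizer step. A minimal sketch of the idea (illustrative only; `hidden_size`, `local_sequence_shard`, and `tp_group` stand in for values a real setup would provide):

    import torch
    import torch.distributed as dist

    # Every TP rank holds a replica of the same LayerNorm but normalizes a
    # different slice of the sequence, so ln.weight.grad is only a partial sum.
    ln = torch.nn.LayerNorm(hidden_size)
    out = ln(local_sequence_shard)  # [local_seq_len, hidden_size]
    out.sum().backward()

    # Summing the partial gradients over the TP group recovers the gradient
    # each rank would have computed from the full sequence.
    dist.all_reduce(ln.weight.grad, op=dist.ReduceOp.SUM, group=tp_group)
    dist.all_reduce(ln.bias.grad, op=dist.ReduceOp.SUM, group=tp_group)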
@@ -1,6 +1,6 @@
+import ctypes
 import random
-from contextlib import nullcontext
+from contextlib import contextmanager
 from functools import partial
 from types import MethodType
 from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper):
         precision: str,
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
+        tp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
+        self.shard_config = shard_config
         self.dp_group = dp_group
+        self.tp_group = tp_group
         self.use_dpp = use_ddp
+        self.require_grad_sync = True
 
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper):
                 dist.all_reduce(param.grad, group=group)
             dist.barrier()
 
-    def no_sync(self) -> Iterator[None]:
-        # no sync grads across data parallel
-        return nullcontext()
+    @contextmanager
+    def no_sync(self):
+        r"""
+        A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization
+        when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass
+        when exiting the context.
+        """
+
+        # Store the current value of 'require_grad_sync' to restore it later.
+        old_require_grad_sync = self.require_grad_sync
+        # Disable automatic gradient synchronization.
+        self.require_grad_sync = False
+        try:
+            if self.use_dpp:
+                # If using data parallel processing (use_dpp), disable synchronization too.
+                with self.module.no_sync():
+                    yield
+            else:
+                yield
+        finally:
+            # Restore the original value of 'require_grad_sync'.
+            self.require_grad_sync = old_require_grad_sync
 
-    def sync_grads(self):
-        # sync grad across data parallel
+    def sync_dp_grads(self):
+        r"""
+        Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
+        This function performs an all-reduce operation to combine gradients from different devices in the DP group.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+
+        # Check if the DP group size is 1, meaning no synchronization is needed.
         if self.dp_group.size() == 1:
             return
+
+        # Iterate through the model's parameters and perform gradient synchronization.
         for p in self.module.parameters():
             if p.grad is not None:
+                # Perform all-reduce to combine gradients from different devices.
                 dist.all_reduce(p.grad, group=self.dp_group)
+                # Normalize the gradient by dividing it by the DP group size.
                 p.grad.div_(self.dp_group.size())
 
+    def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
+        r"""
+        Synchronize gradients that are partially derived within sequence parallelism
+        if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
+        from the module.
+
+        Args:
+            grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
+                provided, gradients will be extracted from the model.
+
+        Returns:
+            None
+        """
+
+        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+            if grads is not None:
+                # Synchronize provided gradient tensors across the tensor parallelism group.
+                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+            else:
+                # Synchronize gradients from the model across the tensor parallelism group.
+                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
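The new `no_sync` follows the same convention as `torch.nn.parallel.DistributedDataParallel.no_sync`: while the context is active, `require_grad_sync` is False and no all-reduce runs, so synchronization must be triggered manually afterwards. A usage sketch (assuming `model` is a `HybridParallelModule` whose forward returns a scalar loss, and `micro_batches` is a list of input dicts; both names are illustrative):

    # Accumulate gradients over several micro-batches without any all-reduce,
    # then synchronize once at the end.
    with model.no_sync():
        for micro_batch in micro_batches:
            loss = model(**micro_batch)
            loss.backward()

    model.sync_dp_grads()  # all-reduce + divide by the DP group size
    model.sync_sp_grads()  # all-reduce partially derived (layer-norm) grads over the TP group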
@@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optim: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         max_norm: float = 0,
@@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
+        self.model = model
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
         self.max_norm = max_norm
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
+        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         super().__init__(optim)
 
+    def backward(self, loss: Tensor, *args, **kwargs):
+        r"""
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs the backward pass for gradient computation. If sequence parallelism is enabled
+        and gradient synchronization is required, it will synchronize gradients that are partially derived
+        within sequence parallelism across tp parallelism groups.
+
+        Args:
+            loss (Tensor): The loss tensor to compute gradients with respect to.
+            *args: Additional positional arguments to be passed to the superclass backward method.
+            **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+        Returns:
+            None
+        """
+
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, *args, **kwargs)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+        Args:
+            tensor (Tensor): The input tensor for which gradients are computed.
+            grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+
+        # Call the superclass backward method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def step(self, *args, **kwargs):
         r"""
         Perform an optimization step.
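With these overrides, a plain training step picks up the sequence-parallel synchronization automatically. Roughly (a sketch; `optimizer` and `loss` are assumed to come from the plugin's `configure` and the forward pass):

    optimizer.zero_grad()
    # backward() runs super().backward() and then, because require_grad_sync is
    # True by default, model.sync_sp_grads() to all-reduce layer-norm gradients.
    optimizer.backward(loss)
    optimizer.step()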
@@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if len(param_gradient_pairs) == 0:
             return 0.0
 
-        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
-        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         norm_type = float(norm_type)
 
         # gradients used for norm calculation.
@@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-            if tp_size > 1:
+            if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
-            if pp_size > 1:
+            if self.pp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
             total_norm = total_norm_cuda.item()
         else:
@@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
                 # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                 # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
-                if tp_size > 1:
+                if self.tp_size > 1:
                     param_for_grad = grad_to_param_mapping[id(grad)]
                     if not is_distributed_tensor(param_for_grad):
-                        grad_norm_exponentiated /= tp_size
+                        grad_norm_exponentiated /= self.tp_size
 
                 # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
                 # it means that this parameter is used in two different pipeline stages.
                 # To avoid redundant norm calculations, we divide the exponent of this norm by
                 # the number of shared stages.
-                if pp_size > 1:
+                if self.pp_size > 1:
                     for shared_param in self.shared_params:
                         if self.stage_manager.stage in shared_param:
                             stage_shared_param = shared_param[self.stage_manager.stage]
@@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 total_norm_exponentiated += grad_norm_exponentiated
 
         total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
-        if tp_size > 1:
+        if self.tp_size > 1:
             # compute norm in tp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
-        if pp_size > 1:
+        if self.pp_size > 1:
             # compute norm in pp process group
            dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
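The clipping math here is the standard distributed p-norm: each rank sums |g|^p over its local gradients, the exponentiated sums are all-reduced over the tp and pp groups, and the p-th root is taken only at the end. A condensed sketch of the same computation (assuming `gradients`, `tp_pg`, and `pp_pg` are already set up; it mirrors the structure of the code above rather than reproducing it exactly):

    import torch
    import torch.distributed as dist

    def distributed_grad_norm(gradients, norm_type=2.0, tp_pg=None, pp_pg=None):
        # Local sum of |g|^p over this rank's gradients.
        local_sum = sum(float(torch.norm(g, norm_type)) ** norm_type for g in gradients)
        total = torch.tensor([local_sum], device="cuda")
        # Combine the partial sums across parallel groups before taking the root.
        if tp_pg is not None and dist.get_world_size(tp_pg) > 1:
            dist.all_reduce(total, op=dist.ReduceOp.SUM, group=tp_pg)
        if pp_pg is not None and dist.get_world_size(pp_pg) > 1:
            dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pp_pg)
        return total.item() ** (1.0 / norm_type)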
@@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
         self,
        optim: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         precision: str = "fp16",
@@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
     ):
+        self.model = model
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
+        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(
@@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             max_norm=max_norm,
         )
 
+    def backward(self, loss: Tensor, *args, **kwargs):
+        r"""
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs the backward pass for gradient computation. If sequence parallelism is enabled
+        and gradient synchronization is required, it will synchronize gradients that are partially derived
+        within sequence parallelism across tp parallelism groups.
+
+        Args:
+            loss (Tensor): The loss tensor to compute gradients with respect to.
+            *args: Additional positional arguments to be passed to the superclass backward method.
+            **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+        Returns:
+            None
+        """
+
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, *args, **kwargs)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+        Args:
+            tensor (Tensor): The input tensor for which gradients are computed.
+            grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+
+        # Call the superclass backward method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         if len(param_gradient_pairs) == 0:
             return 0.0
 
-        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
-        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         norm_type = float(norm_type)
 
         if norm_type == inf:
@@ -374,9 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
 
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
 
-        if tp_size > 1:
+        if self.tp_size > 1:
             dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
-        if pp_size > 1:
+        if self.pp_size > 1:
             dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
 
         total_norm = total_norm_cuda.item()
@@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
                 # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                 # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
-                if tp_size > 1:
+                if self.tp_size > 1:
                     param_for_grad = grad_to_param_mapping[id(grad)]
                     if not is_distributed_tensor(param_for_grad):
-                        grad_norm_exponentiated /= tp_size
+                        grad_norm_exponentiated /= self.tp_size
 
                 # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
                 # it means that this parameter is used in two different pipeline stages.
                 # To avoid redundant norm calculations, we divide the exponent of this norm by
                 # the number of shared stages.
-                if pp_size > 1:
+                if self.pp_size > 1:
                     for shared_param in self.shared_params:
                         if self.stage_manager.stage in shared_param:
                             stage_working_shared_param = shared_param[self.stage_manager.stage]
@@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated
 
         total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
-        if tp_size > 1:
+        if self.tp_size > 1:
             # compute norm in tp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
-        if pp_size > 1:
+        if self.pp_size > 1:
             # compute norm in pp process group
             dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
     ):
+        self.model = model
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
@@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             forced_dtype=forced_dtype,
         )
 
+    def sync_dp_grads(self):
+        r"""
+        Synchronize gradients in the data parallelism dimension.
+
+        This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients
+        in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions,
+        namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization
+        and readability.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+
+        # Call the superclass `_sync_grad` method to synchronize gradients.
+        super()._sync_grad()
+
+    def _sync_sp_grads(self):
+        r"""
+        Synchronize gradients that are partially derived within sequence parallelism.
+
+        This method is responsible for synchronizing partially derived gradients across tp parallelism groups.
+        It identifies gradients that are partially derived or not and synchronizes them.
+        If synchronization is required and gradients are found to be synchronized,
+        it performs the synchronization.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+
+        def _get_all_working_grads() -> List[Tensor]:
+            """Retrieve all working gradients from different parameter groups."""
+            all_working_grads = []
+            for group_id in range(self.num_param_groups):
+                working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+                all_working_grads.extend(working_grads)
+            return all_working_grads
+
+        def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
+            """Identify gradients to be synchronized in the sequence parallelism."""
+            grads_to_sync = []
+            for grad in all_working_grads:
+                param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+                if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
+                    grads_to_sync.append(grad)
+
+            if len(grads_to_sync) > 0:
+                return grads_to_sync
+            else:
+                return None
+
+        # Get all working gradients and gradients to be synchronized.
+        all_working_grads = _get_all_working_grads()
+        grads_to_sync = _get_grads_to_sync(all_working_grads)
+
+        if self.require_grad_sync and grads_to_sync is not None:
+            # Synchronize sequence parallelism gradients if required.
+            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+        else:
+            return
+
+    def backward(self, loss, retain_graph=False):
+        """
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs the backward pass for gradient computation based on a given loss tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across TP parallelism groups.
+
+        Args:
+            loss: The loss tensor to compute gradients with respect to.
+            retain_graph (bool): Whether to retain the computation graph.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, retain_graph)
+
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self._sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor, grad):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across TP parallelism groups.
+
+        Args:
+            tensor: The input tensor for which gradients are computed.
+            grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward_by_grad method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self._sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
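A detail worth noting in `_get_grads_to_sync`: the grad store keys working gradients by the Python `id()` of their parameter, and `ctypes.cast(param_id, ctypes.py_object).value` turns that id back into the live parameter object so `SeqParallelUtils.is_sp_partial_derived_param` can inspect it. The cast is only safe while the parameter is still referenced elsewhere. A tiny self-contained illustration of the trick (not part of the patch):

    import ctypes

    import torch

    p = torch.nn.Parameter(torch.ones(2))
    param_id = id(p)
    # Recover the object from its id; valid only while `p` stays alive.
    recovered = ctypes.cast(param_id, ctypes.py_object).value
    assert recovered is p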
@@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
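The move to keyword arguments here is plausibly the "fix bug using wrong variables" mentioned in the commit body: with `tp_group` newly inserted into the middle of `HybridParallelModule.__init__`'s signature, the old seven-argument positional call would bind everything after `dp_group` to the wrong parameter. Schematically (parameter names as in the signature shown earlier in this diff):

    # Old positional call against the new signature:
    #   HybridParallelModule(model, precision, shard_config, dp_group, use_ddp, ddp_config, custom_policy)
    #   __init__(module,     precision, shard_config, dp_group, tp_group, use_ddp, ddp_config, ...)
    # use_ddp would land in tp_group, and every later argument shifts by one.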
@@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase):
         return_outputs: bool = False,
     ) -> dict:
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
-        # return loss or outputs if needed
+
+        # Create a context for gradient synchronization based on the optimizer type.
+        # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
+        # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once),
+        # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
         with ctx:
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
+
+        # Synchronize the grads of shared parameters of the model.
         model.sync_shared_params()
+
+        # Synchronize sequence parallelism gradients of the model.
+        model.sync_sp_grads()
+
+        # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
+        # Otherwise, synchronize data parallelism gradients of the model.
+        # This is because these are two different forms of data parallelism.
         if isinstance(optimizer, HybridParallelZeroOptimizer):
-            optimizer.sync_grad()
+            optimizer.sync_dp_grads()
         else:
-            model.sync_grads()
+            model.sync_dp_grads()
+
         return outputs
 
     def prepare_dataloader(
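Taken together, `execute_pipeline` now makes every reduction explicit: automatic synchronization is suppressed for all microbatches via `no_sync`, and each kind of sync (shared parameters, sequence-parallel layer norms, data-parallel gradients) runs exactly once after the schedule finishes. As pseudocode (a condensed restatement of the method above, not new behavior):

    # One pipeline step over N microbatches:
    with no_sync():                      # suppress per-microbatch all-reduce
        for _ in range(num_microbatches):
            forward_backward_step(...)   # gradients accumulate locally
    sync_shared_params()                 # pp: parameters shared across stages
    sync_sp_grads()                      # tp: partially derived layer-norm grads
    sync_dp_grads()                      # dp: reduce data-parallel grads once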