mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* skip pipeline inference test
* [hotfix] fixing policies of sequence parallel (#4922)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
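Why this is needed: in sequence parallelism the activations are split along the sequence dimension across the tensor-parallel ranks, and LayerNorm normalizes each token independently, so each rank computes its LayerNorm weight/bias gradients from only its local tokens. The true gradient is the sum over all shards, which is what the added all-reduce supplies. A minimal single-process sketch of that fact (the shapes and the two-way shard split are illustrative assumptions, not from this commit):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16)  # 8 tokens, hidden size 16

def layernorm_weight_grad(tokens):
    # Fresh LayerNorm each call, so every call starts from the same
    # default weight (ones) and bias (zeros).
    ln = nn.LayerNorm(16)
    ln(tokens).sum().backward()
    return ln.weight.grad

full = layernorm_weight_grad(x)  # gradient from the whole sequence
# Two "sequence shards": summing their partial gradients recovers the
# full one, which is exactly what all-reduce(SUM) over the group does.
partial = layernorm_weight_grad(x[:4]) + layernorm_weight_grad(x[4:])
assert torch.allclose(full, partial, atol=1e-6)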
@@ -1,8 +1,82 @@
from contextlib import contextmanager
from typing import List

import torch
import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size


class SeqParallelUtils:
    @staticmethod
    def marked_as_sp_partial_derived_param(param):
        """
        Mark a parameter as partially derived in sequence parallelism.

        Args:
            param: The parameter to mark as partially derived.
        """
        setattr(param, "partial_derived", True)

    @staticmethod
    def is_sp_partial_derived_param(param):
        """
        Check if a parameter is marked as partially derived in sequence parallelism.

        Args:
            param: The parameter to check.

        Returns:
            bool: True if the parameter is marked as partially derived, False otherwise.
        """
        return getattr(param, "partial_derived", False)

    @staticmethod
    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
        """
        Allreduce partial derived gradients across the specified process group.

        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.

        Args:
            tp_group (ProcessGroup): The process group for gradient synchronization.
            model (nn.Module): The model from which gradients will be synchronized.
            grads (List[torch.Tensor]): The list of gradients to be synchronized.

        Raises:
            AssertionError: If both `model` and `grads` are provided or neither is provided.
        """
        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."

        # Get the size of the process group, which determines whether synchronization is needed.
        tp_size = get_world_size(tp_group) if tp_group is not None else 1

        if tp_size == 1:
            # If the process group size is 1, no synchronization is required.
            return

        if model is not None:
            # If `model` is provided, extract partial derived gradients from the model's parameters.
            grads = []
            for p in model.parameters():
                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
                    grads.append(p.grad.data)

            # Flatten and reduce the gradients using the specified process group.
            coalesced = _flatten_dense_tensors(grads)
            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)

            # Unflatten the synchronized gradients and update the model's gradients.
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
        else:
            # If `grads` are provided explicitly, synchronize those gradients directly.
            coalesced = _flatten_dense_tensors(grads)
            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)

class Randomizer:
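For context, a hedged sketch of how these utilities fit into a training step (it assumes torch.distributed is already initialized and that `tp_group` is the tensor-parallel group over which the sequence is split; `mark_layernorm_params` and the training-loop names are illustrative, not part of this commit):

import torch.distributed as dist
from torch import nn

# SeqParallelUtils is the class from the diff above.

def mark_layernorm_params(model: nn.Module) -> None:
    # A shard policy would do this while wrapping the model: tag every
    # LayerNorm parameter so its gradient is known to be a partial sum.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for p in module.parameters():
                SeqParallelUtils.marked_as_sp_partial_derived_param(p)

def training_step(model: nn.Module, optimizer, batch, tp_group: dist.ProcessGroup):
    loss = model(**batch).loss
    loss.backward()
    # Sum the partial LayerNorm gradients across the sequence-parallel
    # (tensor-parallel) ranks before the optimizer consumes them.
    SeqParallelUtils.allreduce_partial_data_grad(tp_group, model=model)
    optimizer.step()
    optimizer.zero_grad()

Note the design choice in `allreduce_partial_data_grad`: the marked gradients are flattened into one contiguous buffer with `_flatten_dense_tensors`, all-reduced once, and scattered back with `_unflatten_dense_tensors`, so the synchronization costs a single collective regardless of how many norm layers the model contains.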