mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
@@ -35,17 +35,21 @@ class SeqParallelUtils:
|
||||
return getattr(param, "partial_derived", False)
|
||||
|
||||
@staticmethod
|
||||
def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
|
||||
def allreduce_partial_data_grad(
|
||||
process_group: ProcessGroup,
|
||||
model: nn.Module = None,
|
||||
grads: List[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Allreduce partial derived gradients across the specified process group.
|
||||
|
||||
This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
|
||||
|
||||
Args:
|
||||
tp_group (ProcessGroup): The process group for gradient synchronization.
|
||||
process_group (ProcessGroup): The process group for gradient synchronization.
|
||||
model (nn.Module): The model from which gradients will be synchronized.
|
||||
grads (List[torch.Tensor]): The list of gradients to be synchronized.
|
||||
|
||||
only_sp_partial (bool): Whether handle all the parameters or only parameters marked as partial derived.
|
||||
Raises:
|
||||
AssertionError: If both `model` and `grads` are provided or neither is provided.
|
||||
"""
|
||||
@@ -53,22 +57,26 @@ class SeqParallelUtils:
|
||||
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
|
||||
|
||||
# Get the size of the process group, which determines whether synchronization is needed.
|
||||
tp_size = get_world_size(tp_group) if tp_group is not None else 1
|
||||
group_size = get_world_size(process_group) if process_group is not None else 1
|
||||
|
||||
if tp_size == 1:
|
||||
if group_size == 1:
|
||||
# If the process group size is 1, no synchronization is required.
|
||||
return
|
||||
|
||||
if model is not None:
|
||||
# If `model` is provided, extract partial derived gradients from the model's parameters.
|
||||
grads = []
|
||||
|
||||
for p in model.parameters():
|
||||
if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
|
||||
grads.append(p.grad.data)
|
||||
if p.grad is not None:
|
||||
if SeqParallelUtils.is_sp_partial_derived_param(p):
|
||||
grads.append(p.grad.data)
|
||||
|
||||
# Flatten and reduce the gradients using the specified process group.
|
||||
if len(grads) == 0:
|
||||
return
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
|
||||
|
||||
# Unflatten the synchronized gradients and update the model's gradients.
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
@@ -76,7 +84,7 @@ class SeqParallelUtils:
|
||||
else:
|
||||
# If `grads` are provided explicitly, synchronize those gradients directly.
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
|
||||
|
Reference in New Issue
Block a user