mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LLaMA model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* use all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loosen tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather autograd func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses-style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several models
* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
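The commit message above covers several sequence-parallel styles: Megatron-style split/gather, ring, and Ulysses-style all-to-all (the path that uses all_to_all_single). As background, here is a minimal sketch, built only on torch.distributed and not taken from this commit, of the all-to-all step that Ulysses-style SP relies on: it converts a sequence-sharded activation into a head-sharded one so attention can operate on the full sequence. The function name, shapes, and sp_group are illustrative and assume an already-initialized sequence-parallel process group.

```python
import torch
import torch.distributed as dist


def ulysses_scatter_heads_gather_sequence(x: torch.Tensor, sp_group) -> torch.Tensor:
    """Turn [B, S/P, H, D] (sequence-sharded, all heads) into [B, S, H/P, D] (full sequence, sharded heads)."""
    B, S_local, H, D = x.shape
    P = dist.get_world_size(sp_group)
    # [B, S_local, H, D] -> [P, S_local, B, H/P, D]; dim 0 enumerates the head group sent to each rank.
    x = x.reshape(B, S_local, P, H // P, D).permute(2, 1, 0, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=sp_group)
    # dim 0 now enumerates source ranks, i.e. sequence chunks in rank order -> [B, S, H/P, D].
    return out.permute(2, 0, 1, 3, 4).reshape(B, P * S_local, H // P, D)
```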
@@ -23,11 +23,13 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
    gather_forward_reducescatter_backward,
    gather_forward_split_backward,
    linear_gather_forward_reducescatter_backward,
    linear_reducescatter_forward_gather_backward,
    linear_with_async_comm,
    reduce_forward,
    reducescatter_forward_gather_backward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule
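As orientation for the imports above, the following is a minimal sketch of the communication pattern that gather_forward_reducescatter_backward provides for the "split_gather" mode: all-gather along the sequence dimension in the forward pass and reduce-scatter the gradient in the backward pass. It is an illustrative stand-in written against raw torch.distributed, not the ColossalAI implementation, and assumes the gathered dimension divides evenly across ranks.

```python
import torch
import torch.distributed as dist


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
    """Illustrative stand-in: all-gather `dim` in forward, reduce-scatter its gradient in backward."""

    @staticmethod
    def forward(ctx, x, group, dim):
        ctx.group, ctx.dim = group, dim
        world_size = dist.get_world_size(group)
        chunks = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(chunks, x.contiguous(), group=group)
        return torch.cat(chunks, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # Each rank's input gradient is the sum, over all ranks, of the gradient of that rank's own chunk.
        grad_chunks = [c.contiguous() for c in grad_output.chunk(world_size, dim=ctx.dim)]
        grad_input = torch.empty_like(grad_chunks[dist.get_rank(ctx.group)])
        dist.reduce_scatter(grad_input, grad_chunks, group=ctx.group)
        return grad_input, None, None
```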
@@ -74,7 +76,7 @@ class Linear1D_Col(ParallelModule):
        device: torch.device = None,
        process_group: ProcessGroup = None,
        gather_output: bool = False,
        seq_parallel: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        overlap: torch.cuda.Stream = None,
        skip_bias_add: bool = False,
@@ -89,7 +91,7 @@ class Linear1D_Col(ParallelModule):
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.seq_parallel = seq_parallel
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
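A hypothetical usage sketch of the updated constructor. In practice these layers are created by ShardFormer policies rather than by hand; the direct construction below only illustrates the new seq_parallel_mode argument. The process group and sizes are illustrative and assume torch.distributed has already been initialized.

```python
import torch
import torch.distributed as dist

from colossalai.shardformer.layer import Linear1D_Col

# Illustrative group covering all ranks; real setups would pass the tensor-parallel group.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

col_linear = Linear1D_Col(
    in_features=1024,
    out_features=4096,
    process_group=tp_group,
    gather_output=False,
    seq_parallel_mode="split_gather",  # None (no SP), "split_gather", or "ring"
    seq_parallel_dim=1,                # sequence dim of a [batch, seq, hidden] activation
    dtype=torch.float16,
)
```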
@@ -196,12 +198,18 @@ class Linear1D_Col(ParallelModule):

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        if self.seq_parallel:
            output_parallel = linear_gather_forward_reducescatter_backward(
                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
            )
        else:

        if self.seq_parallel_mode is None:
            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
        elif self.seq_parallel_mode == "split_gather":
            input_parallel = gather_forward_reducescatter_backward(
                input_parallel, self.process_group, self.seq_parallel_dim
            )
            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
        elif self.seq_parallel_mode == "ring":
            output_parallel = linear_gather_forward_reducescatter_backward(
                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
            )

        if self.gather_output:
            # All-gather across the partitions.
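To make the "split_gather" branch above concrete, here is a single-process shape check with illustrative sizes; the torch.cat stands in for the sequence-dimension all-gather performed across the tensor-parallel group before the column-parallel matmul.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: tensor-parallel size P, batch B, full sequence S, hidden H.
P, B, S, H, H_out = 4, 2, 16, 1024, 4096

input_parallel = torch.randn(B, S // P, H)           # sequence-sharded input chunk on one rank
gathered = torch.cat([input_parallel] * P, dim=1)    # stands in for the seq-dim all-gather
weight_shard = torch.randn(H_out // P, H)            # column-parallel weight shard
output_parallel = F.linear(gathered, weight_shard)   # local matmul over the full sequence

assert output_parallel.shape == (B, S, H_out // P)   # full sequence, sharded output features
```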
@@ -225,7 +233,8 @@ class Linear1D_Row(ParallelModule):
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
        seq_parallel_mode (`str`): The sequence parallelism mode; sequence parallelism is enabled when `seq_parallel_mode` is not None. Defaults to None.
        seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False.
        weight_initializer (:class:`typing.Callable`, optional):
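A hypothetical sketch of the row-parallel counterpart, continuing the column-layer example earlier (same illustrative tp_group). With parallel_input=True it consumes the un-gathered output of Linear1D_Col, and with seq_parallel_mode="split_gather" its output is reduce-scattered back into per-rank sequence shards.

```python
from colossalai.shardformer.layer import Linear1D_Row

row_linear = Linear1D_Row(
    in_features=4096,
    out_features=1024,
    process_group=tp_group,            # same illustrative group as the column sketch above
    parallel_input=True,               # consumes the un-gathered output of Linear1D_Col
    seq_parallel_mode="split_gather",
    seq_parallel_dim=1,
)

# x: [batch, seq/P, 1024] per rank -> col_linear -> [batch, seq, 4096/P]
# -> row_linear -> reduce-scatter back to [batch, seq/P, 1024] per rank
# y = row_linear(col_linear(x))
```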
@@ -245,7 +254,7 @@ class Linear1D_Row(ParallelModule):
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        seq_parallel: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        parallel_input: bool = True,
        skip_bias_add: bool = False,
@@ -265,7 +274,7 @@ class Linear1D_Row(ParallelModule):
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.seq_parallel = seq_parallel
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.num_partitions = dist.get_world_size(self.process_group)
@@ -403,18 +412,26 @@ class Linear1D_Row(ParallelModule):
                        output_parallel_list[i], group=self.process_group, async_op=True
                    )
                    handle_list.append(handle)
                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
                for handle in handle_list:
                    handle.wait()
                output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
            if self.seq_parallel:
                output = linear_reducescatter_forward_gather_backward(
            if self.seq_parallel_mode is None:
                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                output = reduce_forward(output_parallel, self.process_group)
            elif self.seq_parallel_mode == "split_gather":
                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                output = reducescatter_forward_gather_backward(
                    output_parallel, self.process_group, self.seq_parallel_dim
                )
            else:
                output = reduce_forward(output_parallel, self.process_group)
            elif self.seq_parallel_mode == "ring":
                output = linear_reducescatter_forward_gather_backward(
                    input_,
                    self.weight,
                    process_group=self.process_group,
                    dim=self.seq_parallel_dim,
                    ring=True,
                )

        if not self.skip_bias_add:
            if self.bias is not None:
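For reference, a minimal sketch, again not the ColossalAI implementation, of the communication used by the "split_gather" branch above on the row side: the partial full-sequence output is reduce-scattered in the forward pass so each rank keeps the summed result for its own sequence chunk, and the gradient is all-gathered in the backward pass. The non-sequence-parallel branch instead all-reduces the partial output, which is what reduce_forward does.

```python
import torch
import torch.distributed as dist


class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
    """Illustrative stand-in: reduce-scatter `dim` in forward, all-gather the gradient in backward."""

    @staticmethod
    def forward(ctx, partial_output, group, dim):
        ctx.group, ctx.dim = group, dim
        world_size = dist.get_world_size(group)
        chunks = [c.contiguous() for c in partial_output.chunk(world_size, dim=dim)]
        out = torch.empty_like(chunks[dist.get_rank(group)])
        dist.reduce_scatter(out, chunks, group=group)  # sum partial results, keep own sequence chunk
        return out

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # The full-sequence gradient is the concatenation of every rank's chunk gradient.
        grad_chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(grad_chunks, grad_output.contiguous(), group=ctx.group)
        return torch.cat(grad_chunks, dim=ctx.dim), None, None
```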