[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LLaMA model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* use all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several models
* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
@@ -34,7 +34,8 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer
 from .pp_plugin_base import PipelinePluginBase
 
-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
 
 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
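The mesh gains a fourth axis for sequence parallelism, and SUPPORT_SP_MODE enumerates the three modes: "split_gather" and "ring" reuse the tensor parallel group, while "all_to_all" (the Ulysses-style mode from the commit message) gets its own axis. A minimal sketch of the resulting 4D rank layout, with hypothetical sizes and assuming torch.distributed is already initialized:

    # Hypothetical 16-GPU job: mesh shape (dp, pp, tp, sp) = (2, 2, 2, 2).
    # Ranks are laid out in row-major order, so SP (the last axis) varies fastest.
    from colossalai.cluster import ProcessGroupMesh

    DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
    mesh = ProcessGroupMesh(2, 2, 2, 2)
    sp_group = mesh.get_group_along_axis(SP_AXIS)  # this rank's sequence parallel group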
@@ -53,6 +54,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
@@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.shard_config = shard_config
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.sp_group = sp_group
         self.use_dpp = use_ddp
         self.require_grad_sync = True
@@ -168,13 +171,24 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         Returns:
             None
         """
-        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+        if self.shard_config.enable_sequence_parallelism:
+            if self.shard_config.sequence_parallelism_mode == "all_to_all":
+                return
+
+            if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized
+                # across the tensor parallelism group.
+                group = self.tp_group
+            else:
+                raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")
+
             if grads is not None:
                 # Synchronize provided gradient tensors across the tensor parallelism group.
-                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+                SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)
             else:
                 # Synchronize gradients from the model across the tensor parallelism group.
-                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+                SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)
 
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
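Two things happen in this hunk. For "all_to_all" the method returns early: gradient reduction for that mode is delegated to DDP over a fused group (see the dp_group change further down). For "split_gather" and "ring", partial gradients are summed over the TP group; a rough sketch of what that all-reduce amounts to, not the library's actual implementation:

    import torch.distributed as dist

    def allreduce_partial_grads_sketch(grads, process_group):
        # Layers that operate on the gathered full sequence (e.g. LayerNorm)
        # produce partial gradients on each rank; summing across the group
        # restores the complete gradient everywhere.
        for grad in grads:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=process_group)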
@@ -727,10 +741,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
 
         if self.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
-            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+            SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
             return
@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
     Args:
         tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
         pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        sp_size (int): The size of sequence parallelism.
         precision (str, optional): Specifies the precision of parameters during training.
             Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
             Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
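Putting the new arguments together, here are two hypothetical configurations that satisfy the checks introduced below (both assume the distributed environment is already launched):

    from colossalai.booster.plugin import HybridParallelPlugin

    # Megatron-style: the SP group aliases the TP group, so sp_size is forced to 1.
    plugin = HybridParallelPlugin(
        tp_size=4, pp_size=1,
        enable_sequence_parallelism=True, sequence_parallelism_mode="split_gather",
    )

    # Ulysses-style all_to_all: requires tp_size == 1 and pp_size == 1.
    plugin = HybridParallelPlugin(
        tp_size=1, pp_size=1, sp_size=4,
        enable_sequence_parallelism=True, sequence_parallelism_mode="all_to_all",
    )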
@@ -938,6 +953,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self,
         tp_size: int,
         pp_size: int,
+        sp_size: int = None,
         precision: str = "fp16",
         zero_stage: int = 0,
         enable_all_optimization: bool = False,
@@ -945,6 +961,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        sequence_parallelism_mode: str = None,
         enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
@@ -976,14 +993,41 @@ class HybridParallelPlugin(PipelinePluginBase):
         super().__init__()
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
 
         if enable_sequence_parallelism:
-            assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
+            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "split_gather"
+            assert (
+                self.sequence_parallelism_mode in SUPPORT_SP_MODE
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+                assert (
+                    tp_size > 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} requires tensor parallelism (tp_size > 1)"
+                if sp_size != 1:
+                    warnings.warn(
+                        f"In sequence parallelism mode {self.sequence_parallelism_mode}, sp_size is bound to tp_size; the given sp_size will be ignored."
+                    )
+                self.sp_size = 1
+                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            elif self.sequence_parallelism_mode in ["all_to_all"]:
+                assert (
+                    tp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
+                assert (
+                    pp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
+                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+        else:
+            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            assert (
+                sp_size == 1 or sp_size is None
+            ), "sp_size can only be set to a value > 1 when enable_sequence_parallelism is True"
+            self.sp_size = 1
 
         self.tp_size = tp_size
         self.pp_size = pp_size
-        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
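A worked example of how dp_size falls out of these branches, assuming a hypothetical 8-GPU job:

    world_size = 8

    # split_gather / ring: sp_size is forced to 1, so with tp_size=2, pp_size=2:
    assert world_size // (2 * 2) == 2    # dp_size = 2

    # all_to_all: tp_size == pp_size == 1; with sp_size=4:
    assert world_size // (4 * 1) == 2    # dp_size = 2

    # sequence parallelism disabled: tp_size=2, pp_size=1:
    assert world_size // (2 * 1) == 4    # dp_size = 4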
@@ -992,7 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1033,9 +1077,14 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
+            self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        else:
+            self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
 
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
+            sequence_parallel_process_group=self.sp_group,
             pipeline_stage_manager=self.stage_manager,
             enable_tensor_parallelism=self.tp_size > 1,
             enable_all_optimization=self.enable_all_optimization,
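Under "split_gather" and "ring" the SP group deliberately aliases the TP group, since those modes split and gather activations inside tensor parallel regions; only "all_to_all" uses the dedicated SP mesh axis. A quick sanity-check sketch (get_process_group_ranks requires torch >= 2.0):

    import torch.distributed as dist

    if plugin.sequence_parallelism_mode in ["split_gather", "ring"]:
        assert dist.get_process_group_ranks(plugin.sp_group) == \
               dist.get_process_group_ranks(plugin.tp_group)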
@@ -1043,6 +1092,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
+            sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             gradient_checkpoint_config=gradient_checkpoint_config,
@@ -1113,13 +1163,23 @@ class HybridParallelPlugin(PipelinePluginBase):
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
-            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+            else:
+                dp_group = self.dp_group
             model = HybridParallelModule(
                 model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=self.dp_group,
+                dp_group=dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
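This is the "move ulysses grad sync to ddp" item from the commit message: in all_to_all mode each rank holds a different sequence shard but a full parameter replica, so DDP must average gradients over the fused (DP x SP) group, and it is enabled even when dp_size == 1. A sketch of the fused group's size, with hypothetical sizes:

    import torch.distributed as dist

    # With dp_size=2 and sp_size=4 the fused group spans 8 ranks, so every
    # gradient is averaged over all sequence shards as well as data replicas.
    fused_dp_group = plugin.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
    assert dist.get_world_size(fused_dp_group) == 2 * 4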
@@ -1149,7 +1209,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                     tp_process_group=self.tp_group,
                 )
             else:
-                if self.dp_size == 1:
+                zero_dp_size = dist.get_world_size(dp_group)
+                if zero_dp_size == 1:
                     warnings.warn(
                         "Using the ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
@@ -1161,7 +1222,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
                     param_info=param_info,
-                    dp_process_group=self.dp_group,
+                    dp_process_group=dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
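End to end, the fused group also reaches the ZeRO optimizer through dp_process_group, so no extra user-side grad sync is needed for sequence parallelism. A typical usage sketch, assuming model, optimizer, criterion, dataloader, and lr_scheduler already exist:

    from colossalai.booster import Booster

    booster = Booster(plugin=plugin)
    model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
        model, optimizer, criterion, dataloader, lr_scheduler
    )
    # backward/step behave as usual; sequence parallel gradient sync is
    # handled internally by DDP or the ZeRO optimizer.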