mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
@@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
|
||||
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
bert_model = model_fn()
|
||||
enable_all_optimization = True if tp_size > 1 else False
|
||||
|
||||
enable_flash_attention = True if tp_size > 1 else False
|
||||
enable_fused_normalization = True if tp_size > 1 else False
|
||||
enable_jit_fused = True if tp_size > 1 else False
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
pretrained_path = os.path.join(tempdir, "pretrained")
|
||||
@@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
|
||||
plugin = GeminiPlugin(
|
||||
**placement_config,
|
||||
tp_size=tp_size,
|
||||
enable_all_optimization=enable_all_optimization,
|
||||
enable_flash_attention=enable_flash_attention,
|
||||
enable_fused_normalization=enable_fused_normalization,
|
||||
enable_jit_fused=enable_jit_fused,
|
||||
extra_dp_size=extra_dp_size,
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
@@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
|
||||
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
enable_all_optimization = True if tp_size > 1 else False
|
||||
enable_flash_attention = True if tp_size > 1 else False
|
||||
enable_fused_normalization = True if tp_size > 1 else False
|
||||
enable_jit_fused = True if tp_size > 1 else False
|
||||
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
|
||||
plugin = GeminiPlugin(
|
||||
**placement_config,
|
||||
@@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
|
||||
initial_scale=(2**14),
|
||||
tp_size=tp_size,
|
||||
extra_dp_size=extra_dp_size,
|
||||
enable_all_optimization=enable_all_optimization,
|
||||
enable_flash_attention=enable_flash_attention,
|
||||
enable_fused_normalization=enable_fused_normalization,
|
||||
enable_jit_fused=enable_jit_fused,
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
|
Reference in New Issue
Block a user