[shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when useing sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megtron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
This commit is contained in:
Zhongkai Zhao
2024-04-03 17:15:47 +08:00
committed by GitHub
parent 7e0ec5a85c
commit 8e412a548e
33 changed files with 1630 additions and 256 deletions

View File

@@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Callable, Dict, List
@@ -66,8 +67,17 @@ class BertPolicy(Policy):
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
if sp_mode == "ring":
warnings.warn(
f"For Bert, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(
attribute_replacement={
@@ -85,7 +95,7 @@ class BertPolicy(Policy):
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
},
),
@@ -93,7 +103,7 @@ class BertPolicy(Policy):
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
},
),
@@ -101,7 +111,7 @@ class BertPolicy(Policy):
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
},
),
@@ -112,7 +122,7 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={"seq_parallel_mode": sp_mode},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@@ -122,14 +132,14 @@ class BertPolicy(Policy):
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"seq_parallel_mode": sp_mode,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={"seq_parallel_mode": sp_mode},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@@ -151,7 +161,7 @@ class BertPolicy(Policy):
]
)
if use_sequence_parallel:
if sp_mode == "split_gather":
self.append_or_create_method_replacement(
description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
@@ -165,12 +175,12 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
kwargs={"sp_partial_derived": sp_partial_derived},
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
kwargs={"sp_partial_derived": sp_partial_derived},
),
],
policy=policy,