[shardformer] optimize seq parallelism (#6086)

* [shardformer] optimize seq parallelism

* [shardformer] fix gpt2 fused linear col

* [plugin] update gemini plugin

* [plugin] update moe hybrid plugin

* [test] update gpt2 fused linear test

* [shardformer] fix gpt2 fused linear reduce
Author: Hongxin Liu
Date: 2024-10-11 13:44:40 +08:00
Committed by: GitHub
Parent: 6b2c506fc5
Commit: dc2cdaf3e8
13 changed files with 111 additions and 278 deletions
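The tests below drop the `overlap` flag from the sequence-parallel column linear path. For orientation, here is a minimal sketch (not the shardformer implementation; the class name is hypothetical) of the "split_gather" sequence-parallel pattern these tests exercise: the forward pass all-gathers the sequence-sharded activation before the column-parallel matmul, and the backward pass reduce-scatters the gradient back to the shards. With a fixed communication schedule like this, a separate overlap toggle is presumably no longer needed.

import torch
import torch.distributed as dist


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
    """Sketch: all-gather along the sequence dim (assumed dim 0) in forward,
    reduce-scatter the gradient in backward."""

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        # Gather the sequence shards from all ranks and concatenate them.
        shards = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(shards, x.contiguous(), group=group)
        return torch.cat(shards, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        grad_shard = torch.empty(
            grad_output.shape[0] // world_size,
            *grad_output.shape[1:],
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        # Sum-reduce then scatter so each rank keeps the gradient of its shard.
        dist.reduce_scatter_tensor(grad_shard, grad_output.contiguous(), group=ctx.group)
        return grad_shard, None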


@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
         return x


-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -52,7 +52,6 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
         gather_output=True,
         seq_parallel_mode=seq_parallel_mode,
         split_sizes=[64] * 3,
-        overlap=overlap,
     )
     assert linear.weight.shape == torch.Size([48, 192])
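After this change, constructing the column-parallel fused layer in the test presumably reduces to the call below. This is a hedged reconstruction, assuming the shardformer class GPT2FusedLinearConv1D_Col and its from_native_module API (neither is shown in this excerpt):

from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col

# Shard the 192 fused qkv outputs of Conv1D(192, 48) column-wise; the
# overlap argument is simply gone from the call.
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
    linear,
    process_group=None,
    gather_output=True,
    seq_parallel_mode=seq_parallel_mode,
    split_sizes=[64] * 3,  # three 64-wide q/k/v partitions
)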
@@ -121,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
 @parameterize("lazy_init", [False, True])
 @parameterize("seq_parallel_mode", ["split_gather", None])
-@parameterize("overlap", [True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
     check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
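For completeness, a sketch of how such a parameterized check is typically driven in the ColossalAI test suite; the launcher and helper names (spawn, rerun_if_address_is_in_use, colossalai.launch) are assumptions based on the surrounding test conventions, not part of this diff:

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Each spawned worker joins the process group, then runs every
    # @parameterize combination of the check above.
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port)
    check_gpt2_qkv_fused_linear_1d()


@rerun_if_address_is_in_use()
def test_gpt2_qkv_fused_linear_1d():
    spawn(run_dist, nprocs=4)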