Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 01:55:12 +00:00
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism
* [shardformer] fix gpt2 fused linear col
* [plugin] update gemini plugin
* [plugin] update moe hybrid plugin
* [test] update gpt2 fused linear test
* [shardformer] fix gpt2 fused linear reduce
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
         return x
 
 
-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -52,7 +52,6 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
         gather_output=True,
         seq_parallel_mode=seq_parallel_mode,
         split_sizes=[64] * 3,
-        overlap=overlap,
     )
 
     assert linear.weight.shape == torch.Size([48, 192])
@@ -121,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
 
 
 @parameterize("lazy_init", [False, True])
 @parameterize("seq_parallel_mode", ["split_gather", None])
-@parameterize("overlap", [True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
     check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
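For orientation, here is a minimal sketch of what "split_gather" sequence parallelism (the mode the tests parameterize above) does around a pair of column- and row-parallel matmuls: all-gather the sequence-sharded activations before the column-parallel matmul, and reduce-scatter the partial sums back to a sequence shard after the row-parallel matmul. This is an illustrative sketch, not ColossalAI's implementation; gather_seq, reduce_scatter_seq, and seq_parallel_mlp are hypothetical names, it assumes an initialized torch.distributed process group, and the autograd-aware communication a real layer needs is omitted.

import torch
import torch.distributed as dist


def gather_seq(x: torch.Tensor) -> torch.Tensor:
    # All-gather sequence shards along dim 0 before the column-parallel matmul.
    chunks = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(chunks, x)
    return torch.cat(chunks, dim=0)


def reduce_scatter_seq(x: torch.Tensor) -> torch.Tensor:
    # Reduce-scatter partial sums back to a sequence shard after the
    # row-parallel matmul (dist.reduce_scatter_tensor needs PyTorch >= 1.13).
    out = torch.empty(
        x.shape[0] // dist.get_world_size(), *x.shape[1:], device=x.device, dtype=x.dtype
    )
    dist.reduce_scatter_tensor(out, x)
    return out


def seq_parallel_mlp(x_shard: torch.Tensor, w_col: torch.Tensor, w_row: torch.Tensor) -> torch.Tensor:
    # x_shard: [seq / world_size, batch, hidden] local shard of the activations
    full = gather_seq(x_shard)          # [seq, batch, hidden] replicated on every rank
    h = full @ w_col                    # column-parallel: each rank computes a slice of the features
    partial = h @ w_row                 # row-parallel: yields a partial sum over the hidden dim
    return reduce_scatter_seq(partial)  # sum across ranks, return to a sequence shard

Judging from the diff alone, the overlap flag (which gated overlapping that all-gather with computation) is no longer exposed on this code path, so the tests stop parameterizing it.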