[shardformer] Add overlap support for gpt2 (#4535)
* add overlap support for gpt2
* remove unused code
@@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor
 
 
-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
                                                  process_group=None,
                                                  gather_output=True,
                                                  seq_parallel=seq_parallel,
-                                                 n_fused=3)
+                                                 n_fused=3,
+                                                 overlap=overlap)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear.bias.shape == torch.Size([192])
@@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
 
 @parameterize('lazy_init', [False, True])
 @parameterize('seq_parallel', [False, True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel)
+@parameterize('overlap', [True])
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
     check_linear_conv_1d_row(lazy_init, seq_parallel)
 
 
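For background, the new `overlap` flag controls whether the sequence-parallel all-gather of the sharded input is overlapped with computation rather than blocking it. The snippet below is a minimal sketch of that pattern using async collectives from `torch.distributed`; the function and buffer names are illustrative assumptions, not ColossalAI's actual shardformer internals.

```python
import torch
import torch.distributed as dist


def gather_then_matmul_overlapped(local_input: torch.Tensor,
                                  weight: torch.Tensor,
                                  process_group=None) -> torch.Tensor:
    """Illustrative comm/compute overlap; not the real shardformer kernel."""
    world_size = dist.get_world_size(process_group)
    # Buffer for the full (unsharded) sequence, gathered along dim 0.
    gathered = torch.empty(world_size * local_input.shape[0],
                           *local_input.shape[1:],
                           dtype=local_input.dtype,
                           device=local_input.device)
    # Launch the all-gather asynchronously so independent work can run
    # while the collective is in flight.
    handle = dist.all_gather_into_tensor(gathered, local_input,
                                         group=process_group, async_op=True)
    # ... independent computation would go here, hidden behind the collective ...
    handle.wait()
    return gathered @ weight
```

In the test above, `@parameterize('overlap', [True])` (ColossalAI's test decorator, which re-invokes the function once per listed value) exercises only the overlapped path, combined with every `lazy_init`/`seq_parallel` pairing.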