[shardformer] Add overlap support for gpt2 (#4535)

* add overlap support for gpt2

* remove unused code

* remove unused code
Author:    Bin Jia
Date:      2023-08-29 18:30:50 +08:00
Committer: GitHub
Parent:    0387a47e63
Commit:    e241b74f24

5 changed files with 120 additions and 94 deletions
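For readers skimming the diff: the new `overlap` flag toggles whether the sequence-parallel all-gather of the layer input is issued asynchronously, so that communication can overlap with computation. Below is a minimal sketch of that pattern, assuming a `torch.distributed` process group is already initialized; `gather_then_matmul` is a hypothetical helper, not shardformer's actual implementation.

import torch
import torch.distributed as dist

def gather_then_matmul(local_input: torch.Tensor, weight: torch.Tensor,
                       group=None, overlap: bool = True) -> torch.Tensor:
    # Gather the sequence-sharded input from every rank, then run the matmul.
    world_size = dist.get_world_size(group)
    buffer = [torch.empty_like(local_input) for _ in range(world_size)]
    if overlap:
        # async_op=True returns immediately; the all-gather proceeds on its
        # own stream while we do work that does not depend on the result.
        handle = dist.all_gather(buffer, local_input, group=group, async_op=True)
        # ... independent computation could go here ...
        handle.wait()
    else:
        dist.all_gather(buffer, local_input, group=group)
    full_input = torch.cat(buffer, dim=1)  # reassemble along the sequence dim
    return full_input @ weight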

@@ -53,7 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor
 
 
-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -62,7 +62,8 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
                                                                    process_group=None,
                                                                    gather_output=True,
                                                                    seq_parallel=seq_parallel,
-                                                                   n_fused=3)
+                                                                   n_fused=3,
+                                                                   overlap=overlap)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear.bias.shape == torch.Size([192])
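As an aside, `n_fused=3` tells the column shard that the Conv1D weight packs q, k, and v side by side, so each block must be split across ranks separately rather than slicing the whole matrix once. A rough illustration of the resulting shapes (a sketch only; `shard_fused_columns` is a made-up helper):

import torch

def shard_fused_columns(weight: torch.Tensor, n_fused: int, world_size: int, rank: int):
    # weight: (in_features, n_fused * out_per_block), e.g. (48, 192) with n_fused=3.
    blocks = weight.chunk(n_fused, dim=-1)                        # q, k, v blocks of (48, 64)
    shards = [b.chunk(world_size, dim=-1)[rank] for b in blocks]  # this rank's slice of each
    return torch.cat(shards, dim=-1)                              # (48, 192 // world_size)

w = torch.randn(48, 192)
print(shard_fused_columns(w, n_fused=3, world_size=2, rank=0).shape)  # torch.Size([48, 96])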
@@ -129,8 +130,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
 
 @parameterize('lazy_init', [False, True])
 @parameterize('seq_parallel', [False, True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel)
+@parameterize('overlap', [True])
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
     check_linear_conv_1d_row(lazy_init, seq_parallel)
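Net effect on the test matrix: the stacked decorators now iterate the cross product lazy_init × seq_parallel × overlap, i.e. 2 × 2 × 1 = 4 runs. Roughly equivalent, assuming `parameterize` composes as a cross product when stacked (as ColossalAI's testing decorator does):

from itertools import product

# Rough expansion of the stacked @parameterize decorators: 2 x 2 x 1 = 4 runs.
for lazy_init, seq_parallel, overlap in product([False, True], [False, True], [True]):
    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
    check_linear_conv_1d_row(lazy_init, seq_parallel)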