[shardformer] Add overlap support for gpt2 (#4535)

* add overlap support for gpt2

* remove unused code

* remove unused code
Author: Bin Jia
Date: 2023-08-29 18:30:50 +08:00
Committed by: GitHub
Commit: e241b74f24 (parent 0387a47e63)
5 changed files with 120 additions and 94 deletions

@@ -177,6 +177,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
                  async_communication: bool = False,
                  gather_output: bool = False,
                  seq_parallel: bool = False,
+                 overlap: bool = False,
                  skip_bias_add: bool = False,
                  n_fused: int = 3,
                  weight: Optional[Parameter] = None,
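
For reference, a minimal construction sketch showing the new flag in use. Only gather_output, seq_parallel, overlap, skip_bias_add, n_fused, weight, and out_features are visible in the hunks of this commit; in_features and process_group are assumptions about the rest of the signature.

    # Hypothetical usage; in_features and process_group are assumed argument
    # names, the remaining arguments appear in the diff above.
    layer = GPT2FusedLinearConv1D_Col(
        in_features=768,          # GPT-2 base hidden size (example value)
        out_features=2304,        # 3 * hidden: fused q/k/v projection width
        process_group=tp_group,   # assumed: the tensor-parallel process group
        seq_parallel=True,        # overlap only takes effect on this path
        overlap=True,             # the flag added by this commit
        n_fused=3,                # q, k, v fused into one Conv1D weight
    )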
@@ -190,6 +191,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel = seq_parallel
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.n_fused = n_fused
@@ -308,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         if self.seq_parallel:
             input_parallel = input_
             output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
-                                                                           self.process_group, True, 1)
+                                                                           self.process_group, True, 1, self.overlap)
         else:
             # Set up backprop all-reduce.
             input_parallel = reduce_backward(input_, self.process_group)
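
The overlap flag is forwarded into matmul_gather_forward_reducescatter_backward on the sequence-parallel path, but that function's body is not part of this diff. The sketch below illustrates the general communication/computation overlap pattern such a flag typically enables, written with plain torch.distributed primitives rather than ColossalAI's actual kernel; the function name, the chunking scheme, and all variable names are assumptions.

    import torch
    import torch.distributed as dist

    def gather_matmul_overlapped(input_shard, weight, process_group, num_chunks=2):
        # Sketch: while the matmul for chunk i runs, the all-gather for
        # chunk i+1 is already in flight on the communication stream.
        world_size = dist.get_world_size(process_group)
        chunks = [c.contiguous() for c in input_shard.chunk(num_chunks, dim=0)]

        bufs = [[torch.empty_like(c) for _ in range(world_size)] for c in chunks]
        # Issue the first all-gather asynchronously.
        handle = dist.all_gather(bufs[0], chunks[0], group=process_group, async_op=True)

        per_chunk = []
        for i in range(len(chunks)):
            handle.wait()  # chunk i is now fully gathered
            if i + 1 < len(chunks):
                # Start communication for the next chunk before computing on this one.
                handle = dist.all_gather(bufs[i + 1], chunks[i + 1],
                                         group=process_group, async_op=True)
            gathered = torch.cat(bufs[i], dim=0)  # rank-ordered rows of chunk i
            # Split the result back into per-rank pieces so sequence order
            # can be restored below.
            per_chunk.append(torch.matmul(gathered, weight).chunk(world_size, dim=0))

        # Restore the original sequence order: rank-major, chunk-minor.
        ordered = [per_chunk[i][r] for r in range(world_size) for i in range(len(chunks))]
        return torch.cat(ordered, dim=0)

Issuing the next all-gather before computing on the current chunk is what hides the communication latency; with overlap disabled, the gather would have to complete before any matmul starts.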