mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[shardformer] Add overlap support for gpt2 (#4535)
* add overlap support for gpt2 * remove unused code * remove unused code
This commit is contained in:
@@ -177,6 +177,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
seq_parallel: bool = False,
|
||||
overlap: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
@@ -190,6 +191,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
self.out_features = out_features
|
||||
self.gather_output = gather_output
|
||||
self.seq_parallel = seq_parallel
|
||||
self.overlap = overlap
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.n_fused = n_fused
|
||||
@@ -308,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
if self.seq_parallel:
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
|
||||
self.process_group, True, 1)
|
||||
self.process_group, True, 1, self.overlap)
|
||||
else:
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group)
|
||||
|
Reference in New Issue
Block a user