[shardformer] chatglm support sequence parallel (#4482)

* [shardformer] chatglm support sequence parallel

* fix
flybird11111
2023-08-22 23:59:31 +08:00
committed by GitHub
parent 351351a36e
commit 59e252ecdb
11 changed files with 259 additions and 94 deletions


@@ -74,6 +74,7 @@ class Linear1D_Col(ParallelModule):
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
                  seq_parallel: bool = False,
+                 seq_parallel_dim: int = 1,
                  overlap: bool = False,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -87,6 +88,7 @@ class Linear1D_Col(ParallelModule):
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel = seq_parallel
+        self.seq_parallel_dim = seq_parallel_dim
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
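
For context, a hedged usage sketch of the new keyword (the layer sizes, the `dense_h_to_4h` name, and the `tp_group` handle are illustrative assumptions, not taken from this diff): a model whose activations are laid out sequence-first, as ChatGLM's `[seq_len, batch_size, hidden_size]` tensors are, would pass `seq_parallel_dim=0` instead of the batch-first default of `1`.

```python
# Hypothetical usage sketch, assuming torch.distributed is already initialized and
# `tp_group` stands in for the tensor-parallel ProcessGroup that shardformer passes in.
import torch.distributed as dist

from colossalai.shardformer.layer import Linear1D_Col

tp_group = dist.group.WORLD        # placeholder for the real tensor-parallel group

dense_h_to_4h = Linear1D_Col(
    in_features=4096,              # illustrative hidden size
    out_features=16384,
    process_group=tp_group,
    seq_parallel=True,
    seq_parallel_dim=0,            # sequence-first [seq_len, batch, hidden] layout (e.g. ChatGLM)
    gather_output=False,
)
```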
@@ -190,7 +192,8 @@ class Linear1D_Col(ParallelModule):
         bias = self.bias if not self.skip_bias_add else None
 
         if self.seq_parallel:
             output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
-                                                                            self.process_group, True, 1, self.overlap)
+                                                                            self.process_group, True,
+                                                                            self.seq_parallel_dim, self.overlap)
         else:
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
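
The `seq_parallel` branch above is where `seq_parallel_dim` takes effect on the column-parallel path. As a rough mental model (a minimal sketch, not the library's fused implementation, which can also overlap the communication with the matmul when `overlap=True`), the pattern selected by the dimension argument is: all-gather the activation along the sequence dimension in forward, reduce-scatter the gradient along the same dimension in backward.

```python
# Minimal sketch of gather-forward / reduce-scatter-backward along a configurable dim.
# Assumes the gathered dimension is evenly divisible by the group's world size.
import torch
import torch.distributed as dist


class _GatherForwardReduceScatterBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, process_group, dim):
        ctx.process_group = process_group
        ctx.dim = dim
        world_size = dist.get_world_size(process_group)
        # Collect every rank's sequence shard and concatenate into the full sequence.
        chunks = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(chunks, x.contiguous(), group=process_group)
        return torch.cat(chunks, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        # Split the full-sequence gradient and sum-scatter it so each rank
        # receives the gradient of its own sequence shard.
        chunks = [c.contiguous() for c in torch.chunk(grad_output, world_size, dim=ctx.dim)]
        grad_shard = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_shard, chunks, group=ctx.process_group)
        return grad_shard, None, None
```

A ChatGLM-style caller would invoke this with `dim=0`, e.g. `_GatherForwardReduceScatterBackward.apply(sharded_seq, tp_group, 0)`.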
@@ -236,6 +239,7 @@ class Linear1D_Row(ParallelModule):
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
                  seq_parallel: bool = False,
+                 seq_parallel_dim: int = 1,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -254,6 +258,7 @@ class Linear1D_Row(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
         self.seq_parallel = seq_parallel
+        self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)
 
         if skip_bias_add and not bias:
@@ -390,7 +395,8 @@ class Linear1D_Row(ParallelModule):
         else:
             output_parallel = F.linear(input_, self.weight)
             if self.seq_parallel:
-                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
+                                                                      self.seq_parallel_dim)
             else:
                 output = reduce_forward(output_parallel, self.process_group)
 
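
On the row-parallel side the pattern is mirrored. Again as a hedged sketch rather than the actual `linear_reducescatter_forward_gather_backward` implementation: the partial output is reduce-scattered along `seq_parallel_dim` in forward and all-gathered along that dimension in backward, so activations leave the Linear1D_Col / Linear1D_Row pair sharded over the sequence again.

```python
# Mirror sketch of the row-parallel communication selected by seq_parallel_dim.
# Assumes the scattered dimension is evenly divisible by the group's world size.
import torch
import torch.distributed as dist


class _ReduceScatterForwardGatherBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, process_group, dim):
        ctx.process_group = process_group
        ctx.dim = dim
        world_size = dist.get_world_size(process_group)
        # Sum the partial results across ranks and keep only this rank's sequence shard.
        chunks = [c.contiguous() for c in torch.chunk(x, world_size, dim=dim)]
        out = torch.empty_like(chunks[0])
        dist.reduce_scatter(out, chunks, group=process_group)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        # Rebuild the full-sequence gradient by gathering every rank's shard gradient.
        gathered = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(gathered, grad_output.contiguous(), group=ctx.process_group)
        return torch.cat(gathered, dim=ctx.dim), None, None
```

With `seq_parallel_dim=0`, both sketches operate on the leading sequence axis, which is what the sequence-first ChatGLM layout needs.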