[shardformer] chatglm support sequence parallel (#4482)
* [shardformer] chatglm support sequence parallel

  Make the dimension that sequence parallelism splits configurable via a
  new seq_parallel_dim argument (default 1) on Linear1D_Col and
  Linear1D_Row, instead of hard-coding dimension 1, so that models whose
  activations are not batch-first (such as ChatGLM) are supported.

* fix
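In effect, a shardformer policy can now choose which tensor dimension is sharded under sequence parallelism. A minimal usage sketch follows; the in_features/out_features values, the import path, and the process-group setup are assumptions, and only the keyword arguments visible in this diff are confirmed:

    import torch.distributed as dist
    from colossalai.shardformer.layer import Linear1D_Col  # import path assumed

    dist.init_process_group(backend="nccl")
    tp_group = dist.new_group()          # tensor-parallel group (illustrative)

    linear = Linear1D_Col(
        in_features=4096,                # illustrative sizes
        out_features=4096,
        process_group=tp_group,
        seq_parallel=True,
        seq_parallel_dim=0,              # split dim 0 for sequence-first models
        overlap=True,                    # overlap communication with compute
    )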
@@ -74,6 +74,7 @@ class Linear1D_Col(ParallelModule):
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
                  seq_parallel: bool = False,
+                 seq_parallel_dim: int = 1,
                  overlap: bool = False,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -87,6 +88,7 @@ class Linear1D_Col(ParallelModule):
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel = seq_parallel
+        self.seq_parallel_dim = seq_parallel_dim
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
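The default seq_parallel_dim=1 matches the common batch-first activation layout [batch, seq, hidden], where dimension 1 is the sequence. ChatGLM-style models keep activations sequence-first, [seq, batch, hidden], so they need to split dimension 0 instead. A shape-only illustration with made-up sizes:

    import torch

    batch, seq, hidden, tp_size = 2, 16, 64, 4   # illustrative sizes

    # seq_parallel_dim=1: each rank holds its slice of the sequence (batch-first)
    x_batch_first = torch.randn(batch, seq // tp_size, hidden)

    # seq_parallel_dim=0: sequence-first layout, as in ChatGLM
    x_seq_first = torch.randn(seq // tp_size, batch, hidden)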
@@ -190,7 +192,8 @@ class Linear1D_Col(ParallelModule):
         bias = self.bias if not self.skip_bias_add else None

         if self.seq_parallel:
             output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
-                                                                           self.process_group, True, 1, self.overlap)
+                                                                           self.process_group, True,
+                                                                           self.seq_parallel_dim, self.overlap)
         else:
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
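For orientation, linear_gather_forward_reducescatter_backward all-gathers the sequence-sharded input along seq_parallel_dim before the matmul and reduce-scatters the incoming gradient in the backward pass. A minimal sketch of just that communication pattern; this is a simplification, not ColossalAI's actual implementation, which also fuses the linear and can overlap communication with computation:

    import torch
    import torch.distributed as dist

    class _GatherForwardReduceScatterBackward(torch.autograd.Function):
        # Simplified stand-in for the communication inside
        # linear_gather_forward_reducescatter_backward; assumes the sequence
        # length divides evenly across ranks.

        @staticmethod
        def forward(ctx, x, group, dim):
            ctx.group, ctx.dim = group, dim
            world_size = dist.get_world_size(group)
            # forward: all-gather sequence shards into the full sequence
            shards = [torch.empty_like(x) for _ in range(world_size)]
            dist.all_gather(shards, x.contiguous(), group=group)
            return torch.cat(shards, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.group)
            # backward: reduce-scatter the gradient back into per-rank shards
            chunks = [c.contiguous() for c in grad_output.chunk(world_size, dim=ctx.dim)]
            grad_input = torch.empty_like(chunks[0])
            dist.reduce_scatter(grad_input, chunks, group=ctx.group)
            return grad_input, None, None

The dual of this pattern appears on the row-parallel side; see the sketch after the Linear1D_Row hunks below.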
@@ -236,6 +239,7 @@ class Linear1D_Row(ParallelModule):
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
                  seq_parallel: bool = False,
+                 seq_parallel_dim: int = 1,
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
@@ -254,6 +258,7 @@ class Linear1D_Row(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
         self.seq_parallel = seq_parallel
+        self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
@@ -390,7 +395,8 @@ class Linear1D_Row(ParallelModule):
         else:
             output_parallel = F.linear(input_, self.weight)

         if self.seq_parallel:
-            output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+            output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
+                                                                  self.seq_parallel_dim)
         else:
             output = reduce_forward(output_parallel, self.process_group)
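Linear1D_Row uses the dual pattern: the partial matmul results are reduce-scattered along seq_parallel_dim in the forward pass (rather than all-reduced, as reduce_forward does in the non-sequence-parallel branch) and all-gathered in the backward pass. A matching sketch under the same simplifying assumptions:

    import torch
    import torch.distributed as dist

    class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
        # Simplified stand-in for linear_reducescatter_forward_gather_backward;
        # assumes the sequence length divides evenly across ranks.

        @staticmethod
        def forward(ctx, x, group, dim):
            ctx.group, ctx.dim = group, dim
            world_size = dist.get_world_size(group)
            # forward: sum partial outputs across ranks, keeping only this
            # rank's sequence shard
            chunks = [c.contiguous() for c in x.chunk(world_size, dim=dim)]
            out = torch.empty_like(chunks[0])
            dist.reduce_scatter(out, chunks, group=group)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.group)
            # backward: all-gather gradient shards back to the full sequence
            shards = [torch.empty_like(grad_output) for _ in range(world_size)]
            dist.all_gather(shards, grad_output.contiguous(), group=ctx.group)
            return torch.cat(shards, dim=ctx.dim), None, None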