Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)
* [tp] hotfix linear row
* [tp] support uneven split for fused linear
* [tp] support sp for fused linear
* [tp] fix gpt2 mlp policy
* [tp] fix gather fused and add fused linear row
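As a rough illustration of what "uneven split for fused linear" refers to (my own sketch, not the ColossalAI implementation): the output dimension of a fused qkv projection is the concatenation of the q, k and v segments, and each segment has to be sharded across tensor-parallel ranks even when its size is not divisible by the world size, so per-rank shard sizes can differ. The helper name and splitting policy below are assumptions for illustration only.

import torch

def shard_fused_weight(weight: torch.Tensor, split_sizes: list, rank: int, world_size: int) -> torch.Tensor:
    # weight: [fused_out, in_features]; split_sizes: e.g. [q_size, k_size, v_size].
    shards = []
    for segment in torch.split(weight, split_sizes, dim=0):
        # Uneven case: the first (size % world_size) ranks get one extra row.
        base, rem = divmod(segment.size(0), world_size)
        sizes = [base + (1 if r < rem else 0) for r in range(world_size)]
        shards.append(torch.split(segment, sizes, dim=0)[rank])
    return torch.cat(shards, dim=0)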
@@ -428,11 +428,8 @@ class Linear1D_Row(ParallelModule):
                     handle.wait()
                 output = torch.cat(output_parallel_list, dim=-1)
         else:
-            if self.seq_parallel_mode is None:
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
-            elif self.seq_parallel_mode == "split_gather":
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+            if self.seq_parallel_mode == "split_gather":
+                output_parallel = F.linear(input_, self.weight)
                 output = reducescatter_forward_gather_backward(
                     output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
                 )
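For reference, a minimal sketch of the forward-pass communication used by the "split_gather" branch above: each tensor-parallel rank's partial output is reduce-scattered along the sequence dimension, so the partials are summed and the sequence is split in one step (the backward of ColossalAI's reducescatter_forward_gather_backward is the matching all-gather). The helper below is my own illustration, not the library code, and assumes the scattered dimension is divisible by the world size.

import torch
import torch.distributed as dist

def reduce_scatter_along_dim(partial: torch.Tensor, group: dist.ProcessGroup, dim: int) -> torch.Tensor:
    # Split this rank's partial result into world_size chunks along `dim`
    # and let reduce_scatter sum the matching chunks across ranks.
    world_size = dist.get_world_size(group)
    chunks = [c.contiguous() for c in partial.chunk(world_size, dim=dim)]
    output = torch.empty_like(chunks[0])
    dist.reduce_scatter(output, chunks, group=group)
    return output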
@@ -445,8 +442,8 @@ class Linear1D_Row(ParallelModule):
                     ring=True,
                 )
             else:
-                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group)
+                output_parallel = F.linear(input_, self.weight)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)

         if not self.skip_bias_add:
             if self.bias is not None:
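The rewritten else-branch follows the standard row-parallel linear pattern: each rank applies F.linear to its shard of the input and weight, and the partial outputs are summed across the tensor-parallel group; reduce_forward wraps that all-reduce in an autograd-aware op and, with this change, also receives fp8_communication. A minimal sketch of the idea, assuming a plain all-reduce without FP8 and a function name of my own choosing:

import torch
import torch.nn.functional as F
import torch.distributed as dist

def row_parallel_linear(input_shard: torch.Tensor,
                        weight_shard: torch.Tensor,
                        group: dist.ProcessGroup) -> torch.Tensor:
    # Partial result over this rank's slice of the reduction dimension.
    partial = F.linear(input_shard, weight_shard)
    # Sum partial results across the tensor-parallel group.
    dist.all_reduce(partial, group=group)
    return partial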