mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 14:10:29 +00:00
fix the merge
This commit is contained in:
parent
2eb36839c6
commit
88b3f0698c
@ -202,21 +202,21 @@ class Linear1D_Col(ParallelModule):
|
|||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
if self.seq_parallel_mode is None:
|
if self.seq_parallel_mode == "split_gather":
|
||||||
output_parallel = linear_with_async_comm(
|
|
||||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
|
||||||
)
|
|
||||||
elif self.seq_parallel_mode == "split_gather":
|
|
||||||
input_parallel = gather_forward_reducescatter_backward(
|
input_parallel = gather_forward_reducescatter_backward(
|
||||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||||
)
|
)
|
||||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
elif self.seq_parallel_mode == "ring":
|
elif self.seq_parallel_mode == "ring":
|
||||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
|
Loading…
Reference in New Issue
Block a user