diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 085609c01..d77dd4965 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -202,21 +202,21 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication - ) - elif self.seq_parallel_mode == "split_gather": + if self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication + ) elif self.seq_parallel_mode == "ring": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) else: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + ) if self.gather_output: # All-gather across the partitions.