diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index c819cb9a8..edfe07697 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -302,10 +302,8 @@ class Linear1D_Col(ParallelLayer): with seed(ParallelMode.TENSOR): self.reset_parameters(weight_initializer, bias_initializer) self._set_tensor_parallel_attributes() - if self.gather_output: - set_parallel_input(False) - else: - set_parallel_input(True) + is_parallel_output = not self.gather_output + set_parallel_input(is_parallel_output) def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features