diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 3af2e70d6..d39d6e997 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -353,9 +353,14 @@ class Linear1D_Col(ParallelModule): ) else: output_parallel = linear_with_async_comm( - input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + input_parallel, + self.weight, + bias, + self.process_group, + True, + fp8_communication=self.fp8_communication, + use_zbv=self.use_zbv, ) - if self.gather_output: # All-gather across the partitions. output = gather_forward_split_backward(