From 33530425259d995eec8aa32b62b1a1be06511254 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 19 Aug 2024 08:07:51 +0000
Subject: [PATCH] fix the merge

---
 colossalai/zero/low_level/low_level_optim.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 090436cbc..12ef9252b 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -588,9 +588,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
                 self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
         self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
-        for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
-            if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg)
+        if not self._overlap_allgather:
+            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+                if not tensor_bucket.is_empty():
+                    tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
     def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
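
For context, below is a minimal runnable sketch of the pattern this hunk introduces: flush any leftover tensor buckets at the end of step() only when all-gather is not overlapped with compute, and forward the fp8_communication flag so the final flush matches the in-loop gathers. _TensorBucket and flush_buckets are hypothetical stand-ins for illustration, not ColossalAI's actual API; the real TensorBucket.all_gather launches a collective over the process group.

    from typing import Dict, List

    class _TensorBucket:
        # Hypothetical stand-in for ColossalAI's TensorBucket: it queues
        # sharded params and flushes them with one grouped all-gather.
        def __init__(self) -> None:
            self._tensors: List[object] = []

        def add_to_bucket(self, tensor: object) -> None:
            self._tensors.append(tensor)

        def is_empty(self) -> bool:
            return not self._tensors

        def all_gather(self, pg: str, fp8_communication: bool = False) -> None:
            # The real method issues a collective over `pg`; with
            # fp8_communication=True the payload is cast to FP8 first.
            print(f"all-gather {len(self._tensors)} tensors over {pg}, fp8={fp8_communication}")
            self._tensors.clear()

    def flush_buckets(pg_to_tensor_bucket: Dict[str, _TensorBucket],
                      overlap_allgather: bool, fp8_communication: bool) -> None:
        # The fix: flush only when all-gather is NOT overlapped with
        # compute (in overlap mode the gather is triggered lazily
        # elsewhere), and propagate the fp8 flag, which the old code
        # dropped on this final-flush path.
        if not overlap_allgather:
            for pg, bucket in pg_to_tensor_bucket.items():
                if not bucket.is_empty():
                    bucket.all_gather(pg, fp8_communication=fp8_communication)

    if __name__ == "__main__":
        buckets = {"dp_group_0": _TensorBucket()}
        buckets["dp_group_0"].add_to_bucket(object())
        flush_buckets(buckets, overlap_allgather=False, fp8_communication=True)

The guard mirrors the in-loop branch earlier in step(): when overlap is enabled, eagerly gathering here would duplicate the communication that the overlapped path already schedules.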