diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 50d4f51d3..8f828bd6c 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper):
         param = self.id_to_real_params[param_id]
         fake_param = self.id_to_fake_params.get(param_id, None)
         chunk = self.chunk_manager.get_chunk(param)
-        dp_group = chunk.torch_pg
-        rank = dist.get_rank(dp_group)
+        zero_group = chunk.torch_pg
+        rank = dist.get_rank(zero_group)
         master_rank = 0
         collected_states = {}
 
@@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper):
         local_state_names = None
         if fake_param is not None:
             local_state_names = list(self.optim.state[fake_param].keys())
-        gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
+        gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
         dist.barrier()
-        dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
+        dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
         state_names = None
         for names in gathered_state_names:
             if names is not None:
@@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper):
         _, shard_offset, shard_size = self.get_offsets(param_id)
 
         # Collectors gather state shards through all_gathering.
-        gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
+        gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
 
         dist.barrier()
-        dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
+        dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
 
         if is_collector:
             for state_shard in gathered_state_shards:
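
Beyond the rename of `dp_group` to `zero_group` (the group held by `chunk.torch_pg` shards ZeRO states, not a plain data-parallel replica group), the last hunk also fixes a latent bug: `dist.all_gather_object` falls back to the global WORLD group when `group` is omitted, which is incorrect whenever the ZeRO group is a strict subset of all ranks. Below is a minimal sketch of the corrected gather pattern, not the GeminiOptimizer code itself; the `gather_shards` helper and `zero_group` handle are illustrative names, and an initialized process group (e.g. via torchrun) is assumed.

    import torch.distributed as dist

    def gather_shards(local_shard, zero_group):
        # The output list needs one slot per rank *in the group*; sizing it
        # from the default WORLD size would be wrong for subgroups.
        gathered = [None for _ in range(dist.get_world_size(zero_group))]
        # Passing `group` explicitly keeps the collective on the ZeRO group;
        # omitting it silently falls back to the global WORLD group.
        dist.all_gather_object(gathered, local_shard, group=zero_group)
        return gathered

Since `all_gather_object` requires the output list length to match the world size of the group it runs on, the old call that sized the list from the ZeRO group but gathered on WORLD could mismatch as soon as the two groups differ.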