mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-06 12:07:00 +00:00
[gemini] fix Gemini optimizer: saving Shardformer in Gemini raised "list assignment index out of range" (#5085)
This commit is contained in:
parent
0d482302a1
commit
4ccb9ded7d
@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
param = self.id_to_real_params[param_id]
|
param = self.id_to_real_params[param_id]
|
||||||
fake_param = self.id_to_fake_params.get(param_id, None)
|
fake_param = self.id_to_fake_params.get(param_id, None)
|
||||||
chunk = self.chunk_manager.get_chunk(param)
|
chunk = self.chunk_manager.get_chunk(param)
|
||||||
dp_group = chunk.torch_pg
|
zero_group = chunk.torch_pg
|
||||||
rank = dist.get_rank(dp_group)
|
rank = dist.get_rank(zero_group)
|
||||||
master_rank = 0
|
master_rank = 0
|
||||||
collected_states = {}
|
collected_states = {}
|
||||||
|
|
||||||
@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
local_state_names = None
|
local_state_names = None
|
||||||
if fake_param is not None:
|
if fake_param is not None:
|
||||||
local_state_names = list(self.optim.state[fake_param].keys())
|
local_state_names = list(self.optim.state[fake_param].keys())
|
||||||
gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
|
gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
|
dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
|
||||||
state_names = None
|
state_names = None
|
||||||
for names in gathered_state_names:
|
for names in gathered_state_names:
|
||||||
if names is not None:
|
if names is not None:
|
||||||
@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||||||
_, shard_offset, shard_size = self.get_offsets(param_id)
|
_, shard_offset, shard_size = self.get_offsets(param_id)
|
||||||
|
|
||||||
# Collectors gather state shards through all_gathering.
|
# Collectors gather state shards through all_gathering.
|
||||||
gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
|
gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
|
||||||
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
|
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
|
||||||
|
|
||||||
if is_collector:
|
if is_collector:
|
||||||
for state_shard in gathered_state_shards:
|
for state_shard in gathered_state_shards:
|
||||||
|
Loading…
Reference in New Issue
Block a user