[zero] zero optim state_dict takes only_rank_0 (#1384)
* zero optim state_dict takes only_rank_0
* fix unit test
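The new only_rank_0 flag lets the caller decide whether the gathered optimizer state is materialized on every data-parallel rank or only on DP rank 0 (all other ranks get None back). A minimal usage sketch, assuming a plain data-parallel run where global rank 0 is also DP rank 0; save_checkpoint, model, and path are illustrative names, not part of this commit:

import torch
import torch.distributed as dist

def save_checkpoint(path, model, optimizer, only_rank_0: bool = True):
    # optimizer is assumed to be a ZeroOptimizer from the usual colossalai.zero setup.
    optim_state = optimizer.state_dict(only_rank_0=only_rank_0)
    if only_rank_0:
        # Non-zero DP ranks receive None and skip the write, so the full
        # optimizer state is gathered and serialized only once.
        if dist.get_rank() == 0:
            torch.save({'model': model.state_dict(), 'optim': optim_state}, path)
    else:
        # Every rank holds a full copy and may write its own checkpoint,
        # e.g. to node-local storage.
        torch.save({'model': model.state_dict(), 'optim': optim_state}, path)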
@@ -193,8 +193,9 @@ class ZeroOptimizer(ColossalaiOptimizer):
         if isinstance(val, torch.Tensor):
             self.chunk_manager.add_extern_static_tensor(val)
 
-    def state_dict(self):
-        r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
+    def state_dict(self, only_rank_0: bool = True):
+        r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
+        This saves memory usage.
 
         It contains two entries:
 
@@ -204,7 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             parameter group is a dict
         """
         is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
-        if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
+        if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
             return
         optim_state_dict = super().state_dict()
         scaler_state_dict = self.grad_scaler.state_dict()
@@ -214,14 +215,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
         local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
         if not self.chunk_manager.process_group.has_cpu_groups:
             self.chunk_manager.process_group.set_cpu_groups()
-        dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
         output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
-        dist.gather_object(local_state,
-                           output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
-                           dst=dst_rank,
-                           group=self.chunk_manager.process_group.cpu_dp_process_group())
-        if not is_rank_0:
-            return
+        if only_rank_0:
+            dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
+            dist.gather_object(local_state,
+                               output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
+                               dst=dst_rank,
+                               group=self.chunk_manager.process_group.cpu_dp_process_group())
+            if not is_rank_0:
+                return
+        else:
+            dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
         for state in output:
             optim_state_dict['state'].update(state)
         return optim_state_dict
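For reference, the else branch swaps dist.gather_object (the gathered list lands only on the destination rank) for dist.all_gather_object (every rank gets a full copy), which is what makes only_rank_0=False return the state everywhere. A standalone sketch of the two collectives, assuming a gloo process group launched with torchrun; all names here are illustrative:

import torch.distributed as dist

def demo():
    dist.init_process_group(backend='gloo')  # torchrun supplies rank/world size via env vars
    rank, world_size = dist.get_rank(), dist.get_world_size()
    # Each rank contributes a small dict, standing in for its optimizer state shard.
    local_state = {rank: f'shard-{rank}'}

    # gather_object: only the destination rank receives the full list; others pass None.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_state, gathered, dst=0)
    # rank 0 now holds [{0: 'shard-0'}, {1: 'shard-1'}, ...]; other ranks hold None

    # all_gather_object: every rank receives the full list.
    everywhere = [None] * world_size
    dist.all_gather_object(everywhere, local_state)

    dist.destroy_process_group()

if __name__ == '__main__':
    demo()  # e.g. torchrun --nproc_per_node=2 demo.py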