[hotfix] fix zero optim save (#6191)

Hongxin Liu 2025-02-14 15:09:50 +08:00 committed by GitHub
parent 014837e725
commit 5ff5323538


@@ -786,7 +786,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         """
         zero_state = dict()
         device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
+        for param_group in self.optim.param_groups:
+            for param in param_group["params"]:
+                if param not in self.optim.state:
+                    continue
+                state = self.optim.state[param]
                 working_param = self.master_to_working_param[id(param)]
                 pg = self.param_to_pg[working_param]
                 if not only_on_master or get_nd_rank(pg) == 0:
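The hunk above swaps iteration over self.optim.state.items() for a walk over self.optim.param_groups, which appears to make the visiting order follow the param groups (identical on every rank) and to skip parameters that have no optimizer state yet instead of silently omitting or reordering them. A minimal sketch of the two iteration patterns on a plain torch.optim.Adam; the toy model and optimizer below are illustrative, not ColossalAI code:

import torch

# Toy setup so the optimizer actually holds some state.
model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
opt.step()

# Old pattern: order follows the insertion order of the optimizer's state
# dict, and a parameter with no state yet simply never appears.
old_order = [p for p, _ in opt.state.items()]

# New pattern: walk param_groups (the same on every rank) and skip
# state-less parameters explicitly.
new_order = []
for group in opt.param_groups:
    for p in group["params"]:
        if p not in opt.state:
            continue
        new_order.append(p)

# In this toy case both orders coincide; the param_groups walk is the one
# that stays deterministic across ranks.
assert [id(p) for p in old_order] == [id(p) for p in new_order]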
@@ -804,7 +808,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                         param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
                     if not only_on_master or get_nd_rank(pg) == 0:
                         if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                            pinned_state_dicts[param][k] = torch.empty_like(
+                                param_state, pin_memory=True, device="cpu"
+                            )
                         if pinned_state_dicts is not None:
                             pinned_state_dicts[param][k].copy_(param_state)
                             zero_state[param][k] = pinned_state_dicts[param][k]
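This hunk only re-wraps the pinned-buffer allocation to fit the deeper nesting introduced above; the underlying pattern is to allocate a page-locked CPU tensor once per state key and reuse it via copy_ on later saves. A minimal sketch of that caching pattern, with illustrative names (pinned, key, src) rather than the ColossalAI structures:

import torch

pinned = {}                    # cache of pinned CPU buffers, one per state key
src = torch.randn(8)           # stands in for the gathered optimizer state tensor
key = "exp_avg"

if key not in pinned:
    # Pinned memory speeds up device-to-host copies; pinning generally
    # requires a CUDA-enabled PyTorch build.
    pinned[key] = torch.empty_like(src, device="cpu", pin_memory=True)
pinned[key].copy_(src)
assert torch.equal(pinned[key], src)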
@@ -865,17 +871,21 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
-        idx2master = {}
+        master2idx = {}
         cnt = 0
         for param_group in self.optim.param_groups:
             for param in param_group["params"]:
-                idx2master[cnt] = param
+                master2idx[param] = cnt
                 cnt += 1
-        for param_idx, states in local_states.items():
+        for param_group in self.optim.param_groups:
+            for master_param in param_group["params"]:
+                param_idx = master2idx[master_param]
+                states = local_states[param_idx]
                 current_block_size = 0
                 if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                     pinned_state_dicts[param_idx] = {}
-            master_param = idx2master[param_idx]
                 working_param = self.master_to_working_param[id(master_param)]
                 pg = self.param_to_pg[working_param]
                 if not only_on_master or get_nd_rank(pg) == 0:
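In this hunk the shard writer stops iterating local_states directly and instead walks self.optim.param_groups, using a master2idx map to find each parameter's entry in self.optim.state_dict()["state"], which PyTorch keys by the parameter's position across param groups. A minimal sketch of that index mapping on a plain torch.optim.Adam; the toy model and optimizer are illustrative, not ColossalAI code:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
opt.step()

# Build the parameter -> state-dict index map in param_groups order.
master2idx, cnt = {}, 0
for group in opt.param_groups:
    for p in group["params"]:
        master2idx[p] = cnt
        cnt += 1

# state_dict()["state"] is keyed by those same integer indices, so the
# param_groups walk visits the states in a deterministic order.
local_states = opt.state_dict()["state"]
for group in opt.param_groups:
    for p in group["params"]:
        states = local_states[master2idx[p]]   # e.g. "step", "exp_avg", ...
        assert torch.equal(states["exp_avg"], opt.state[p]["exp_avg"])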