[hotfix] fix zero optim save (#6191)

Author: Hongxin Liu
Committed: 2025-02-14 15:09:50 +08:00 (via GitHub)
Parent: 014837e725
Commit: 5ff5323538

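The change below rewrites optimizer-state gathering in LowLevelZeroOptimizer so that both the full and the sharded save paths walk self.optim.param_groups instead of iterating the optimizer's state mapping directly: every rank then visits parameters in the same group-defined order, and parameters that have no optimizer state yet are skipped explicitly. The snippet below is a minimal sketch, not part of the commit, showing why the param-group walk lines up with the integer keys of optim.state_dict()["state"]; the toy Adam setup and variable names are illustrative assumptions.

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
optim = torch.optim.Adam(params, lr=1e-3)

# only params 0 and 2 receive gradients, so only they acquire optimizer state
params[0].grad = torch.zeros_like(params[0])
params[2].grad = torch.zeros_like(params[2])
optim.step()

local_states = optim.state_dict()["state"]  # keyed by position within param_groups

master2idx, cnt = {}, 0
for group in optim.param_groups:
    for p in group["params"]:
        master2idx[p] = cnt
        cnt += 1

for group in optim.param_groups:
    for p in group["params"]:
        if p not in optim.state:      # stateless parameter: nothing to save
            continue
        idx = master2idx[p]
        assert idx in local_states    # index-keyed view matches the group order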

@@ -786,30 +786,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         """
         zero_state = dict()
         device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
-            working_param = self.master_to_working_param[id(param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                zero_state[param] = copy.deepcopy(state)
-            else:
-                zero_state[param] = {}
-            if pinned_state_dicts is not None and param not in pinned_state_dicts:
-                pinned_state_dicts[param] = {}
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param][k].copy_(param_state)
-                        zero_state[param][k] = pinned_state_dicts[param][k]
-                    else:
-                        zero_state[param][k] = param_state.cpu()
+        for param_group in self.optim.param_groups:
+            for param in param_group["params"]:
+                if param not in self.optim.state:
+                    continue
+                state = self.optim.state[param]
+                working_param = self.master_to_working_param[id(param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    zero_state[param] = copy.deepcopy(state)
+                else:
+                    zero_state[param] = {}
+                if pinned_state_dicts is not None and param not in pinned_state_dicts:
+                    pinned_state_dicts[param] = {}
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                        param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                                pinned_state_dicts[param][k] = torch.empty_like(
+                                    param_state, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param][k].copy_(param_state)
+                            zero_state[param][k] = pinned_state_dicts[param][k]
+                        else:
+                            zero_state[param][k] = param_state.cpu()

         states_dict = self._pack_state(zero_state)
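When a pinned_state_dicts cache is supplied, the gathered state tensor is staged into a lazily allocated, page-locked CPU buffer that is reused across saves instead of a fresh pageable tensor. A standalone sketch of that staging pattern follows; the helper name and cache layout are assumptions rather than the repository's API, and pinning host memory requires a CUDA-enabled runtime.

import torch

def stage_to_pinned(pinned_cache: dict, key: str, gathered: torch.Tensor) -> torch.Tensor:
    # Allocate the pinned host buffer once per state key, mirroring the diff's
    # torch.empty_like(..., pin_memory=True, device="cpu"), then reuse it.
    if key not in pinned_cache:
        pinned_cache[key] = torch.empty_like(gathered, pin_memory=True, device="cpu")
    pinned_cache[key].copy_(gathered)  # device-to-host copy into page-locked memory
    return pinned_cache[key]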
@@ -865,48 +871,52 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]

-        idx2master = {}
+        master2idx = {}
         cnt = 0
         for param_group in self.optim.param_groups:
             for param in param_group["params"]:
-                idx2master[cnt] = param
+                master2idx[param] = cnt
                 cnt += 1
-        for param_idx, states in local_states.items():
-            current_block_size = 0
-            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
-                pinned_state_dicts[param_idx] = {}
-            master_param = idx2master[param_idx]
-            working_param = self.master_to_working_param[id(master_param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                current_block = copy.deepcopy(states)
-            else:
-                current_block = {}
-            for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                            pinned_state_dicts[param_idx][k] = torch.empty_like(
-                                state_tensor, pin_memory=True, device="cpu"
-                            )
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                        current_block[k] = pinned_state_dicts[param_idx][k]
-                    else:
-                        current_block[k] = state_tensor.cpu()
-                    current_block_size += calculate_tensor_size(state_tensor)
-            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
-                yield ret_block, ret_block_size
-                ret_block = dict()
-                ret_block_size = 0
-            ret_block[param_idx] = current_block
-            ret_block_size += current_block_size
+        for param_group in self.optim.param_groups:
+            for master_param in param_group["params"]:
+                param_idx = master2idx[master_param]
+                states = local_states[param_idx]
+                current_block_size = 0
+                if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
+                    pinned_state_dicts[param_idx] = {}
+                working_param = self.master_to_working_param[id(master_param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    current_block = copy.deepcopy(states)
+                else:
+                    current_block = {}
+                for k, v in states.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                                pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                    state_tensor, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                            current_block[k] = pinned_state_dicts[param_idx][k]
+                        else:
+                            current_block[k] = state_tensor.cpu()
+                        current_block_size += calculate_tensor_size(state_tensor)
+                if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                    yield ret_block, ret_block_size
+                    ret_block = dict()
+                    ret_block_size = 0
+                ret_block[param_idx] = current_block
+                ret_block_size += current_block_size

         yield ret_block, ret_block_size
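The sharded path accumulates one block of gathered states per parameter and yields a shard whenever adding the next block would push the running size past max_shard_size, never emitting an empty shard. A minimal sketch of that accumulation pattern, with illustrative names and plain-integer size bookkeeping standing in for calculate_tensor_size:

from typing import Dict, Iterator, Tuple

def shard_blocks(
    blocks: Dict[int, dict], sizes: Dict[int, int], max_shard_size: int
) -> Iterator[Tuple[Dict[int, dict], int]]:
    ret_block, ret_block_size = {}, 0
    for idx, block in blocks.items():
        block_size = sizes[idx]
        # flush the current shard before it overflows, but never emit an empty shard
        if ret_block_size + block_size > max_shard_size and len(ret_block) > 0:
            yield ret_block, ret_block_size
            ret_block, ret_block_size = {}, 0
        ret_block[idx] = block
        ret_block_size += block_size
    yield ret_block, ret_block_size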