Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-27 11:31:58 +00:00
[hotfix] fix zero optim save (#6191)
This commit is contained in:
parent 014837e725
commit 5ff5323538
@@ -786,30 +786,36 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         """
         zero_state = dict()
         device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
-            working_param = self.master_to_working_param[id(param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                zero_state[param] = copy.deepcopy(state)
-            else:
-                zero_state[param] = {}
-
-            if pinned_state_dicts is not None and param not in pinned_state_dicts:
-                pinned_state_dicts[param] = {}
-
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param][k].copy_(param_state)
-                            zero_state[param][k] = pinned_state_dicts[param][k]
-                        else:
-                            zero_state[param][k] = param_state.cpu()
+        for param_group in self.optim.param_groups:
+            for param in param_group["params"]:
+                if param not in self.optim.state:
+                    continue
+                state = self.optim.state[param]
+                working_param = self.master_to_working_param[id(param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    zero_state[param] = copy.deepcopy(state)
+                else:
+                    zero_state[param] = {}
+
+                if pinned_state_dicts is not None and param not in pinned_state_dicts:
+                    pinned_state_dicts[param] = {}
+
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                        param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                                pinned_state_dicts[param][k] = torch.empty_like(
+                                    param_state, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param][k].copy_(param_state)
+                                zero_state[param][k] = pinned_state_dicts[param][k]
+                            else:
+                                zero_state[param][k] = param_state.cpu()
 
         states_dict = self._pack_state(zero_state)
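The substantive change in this hunk is the iteration order. The old code walked self.optim.state.items(), but that dict only contains parameters the optimizer has already touched and its contents can differ between ranks; since each visited parameter issues a collective all_gather_into_flat_tensor_nd, a mismatch in order or count between ranks could stall or corrupt the save. The new code walks self.optim.param_groups, whose ordering is fixed and identical on every rank, and skips parameters that have no state yet. A standalone sketch of that pattern in plain PyTorch (collect_states is a hypothetical helper, not the ColossalAI method):

import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())

def collect_states(optim):
    # Walk param_groups so the visiting order is deterministic, and skip
    # parameters the optimizer has not created state for yet.
    collected = {}
    for group in optim.param_groups:
        for param in group["params"]:
            if param not in optim.state:
                continue
            state = optim.state[param]
            collected[param] = {
                k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state.items()
            }
    return collected

print(len(collect_states(optim)))  # 0: Adam creates state lazily, no step taken yet

model(torch.randn(2, 4)).sum().backward()
optim.step()
print(len(collect_states(optim)))  # 2: weight and bias now carry exp_avg/exp_avg_sq/step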
@@ -865,48 +871,52 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
 
-        idx2master = {}
+        master2idx = {}
         cnt = 0
         for param_group in self.optim.param_groups:
             for param in param_group["params"]:
-                idx2master[cnt] = param
+                master2idx[param] = cnt
                 cnt += 1
-        for param_idx, states in local_states.items():
-            current_block_size = 0
-            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
-                pinned_state_dicts[param_idx] = {}
-            master_param = idx2master[param_idx]
-            working_param = self.master_to_working_param[id(master_param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                current_block = copy.deepcopy(states)
-            else:
-                current_block = {}
-
-            for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                            pinned_state_dicts[param_idx][k] = torch.empty_like(
-                                state_tensor, pin_memory=True, device="cpu"
-                            )
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                            current_block[k] = pinned_state_dicts[param_idx][k]
-                        else:
-                            current_block[k] = state_tensor.cpu()
-                    current_block_size += calculate_tensor_size(state_tensor)
-
-            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
-                yield ret_block, ret_block_size
-                ret_block = dict()
-                ret_block_size = 0
-
-            ret_block[param_idx] = current_block
-            ret_block_size += current_block_size
+
+        for param_group in self.optim.param_groups:
+            for master_param in param_group["params"]:
+                param_idx = master2idx[master_param]
+                states = local_states[param_idx]
+
+                current_block_size = 0
+                if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
+                    pinned_state_dicts[param_idx] = {}
+                working_param = self.master_to_working_param[id(master_param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    current_block = copy.deepcopy(states)
+                else:
+                    current_block = {}
+
+                for k, v in states.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                                pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                    state_tensor, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                                current_block[k] = pinned_state_dicts[param_idx][k]
+                            else:
+                                current_block[k] = state_tensor.cpu()
+                        current_block_size += calculate_tensor_size(state_tensor)
+
+                if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                    yield ret_block, ret_block_size
+                    ret_block = dict()
+                    ret_block_size = 0
+
+                ret_block[param_idx] = current_block
+                ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
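The second hunk applies the same reordering to the sharded save path: the index map is inverted from idx2master to master2idx, and the loop now walks param_groups instead of local_states.items(), fetching each parameter's local state by its integer index. This works because torch.optim.Optimizer.state_dict()["state"] keys its entries by the parameter's position across param_groups. A minimal sketch of that mapping with a plain Adam optimizer (only the master2idx name mirrors the diff; the rest is illustrative):

import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optim.step()

# state_dict()["state"] is keyed by integer position: {0: {...}, 1: {...}}
local_states = optim.state_dict()["state"]

# Build param -> index in the same param_groups order the indices were assigned in.
master2idx = {}
cnt = 0
for group in optim.param_groups:
    for param in group["params"]:
        master2idx[param] = cnt
        cnt += 1

# Walk parameters deterministically and look up their local state by index.
for group in optim.param_groups:
    for param in group["params"]:
        idx = master2idx[param]
        states = local_states[idx]
        print(idx, sorted(states))  # 0 ['exp_avg', 'exp_avg_sq', 'step'], then 1 [...]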
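Both hunks leave the optional pinned_state_dicts path intact: a page-locked CPU buffer is allocated once per state key with torch.empty_like(..., pin_memory=True, device="cpu") and refilled with copy_() on every save, so repeated checkpoints reuse the same host memory and device-to-host copies stay fast. A minimal sketch of that reuse pattern, assuming a hypothetical pinned_buffers dict and save_state_tensor helper (not part of ColossalAI):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
pin = torch.cuda.is_available()  # pinning host memory needs an accelerator build

pinned_buffers = {}  # plays the role of pinned_state_dicts[param] in the diff

def save_state_tensor(key, gathered):
    # Allocate the pinned host buffer lazily, only on the first save of this key.
    if key not in pinned_buffers:
        pinned_buffers[key] = torch.empty_like(gathered, device="cpu", pin_memory=pin)
    # Reuse the buffer on every later save; on CUDA this copy can overlap compute.
    pinned_buffers[key].copy_(gathered)
    return pinned_buffers[key]

gathered_state = torch.randn(1024, device=device)  # stands in for a gathered optimizer state
cpu_copy = save_state_tensor("exp_avg", gathered_state)
print(cpu_copy.is_pinned() if pin else cpu_copy.device)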