[checkpointio] fix zero optimizer async save memory (#6151)

* [checkpointio] fix zero optimizer async save memory

* [checkpointio] fit new tensornvme api

* [checkpointio] fit new tensornvme api
Author: Hongxin Liu
Date: 2024-11-25 14:46:31 +08:00 (committed by GitHub)
Parent: 8ecff0cb7f
Commit: ab856fd308
7 changed files with 57 additions and 42 deletions
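
The LowLevelZeroOptimizer hunks below thread a new `only_on_master` flag through `state_dict` and `state_dict_shard`: every rank still joins the all-gather of each sharded state tensor, but only rank 0 of the parameter's process group deep-copies the state and materializes the gathered CPU copy (and its pinned buffer); the other ranks keep empty placeholders, which is what saves host memory during async checkpointing. A minimal sketch of that pattern, assuming plain `torch.distributed` collectives rather than ColossalAI's `all_gather_into_flat_tensor_nd` helper (the function name `gather_state_on_master` and the `pinned_cache` argument are illustrative only):

```python
# A minimal sketch of the "gather, but only keep on the master rank" pattern.
# The helper name, `pinned_cache`, and the plain torch.distributed calls are
# illustrative assumptions; the real logic lives in LowLevelZeroOptimizer below.
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist


def gather_state_on_master(
    state: Dict[str, Any],
    pg: dist.ProcessGroup,
    pinned_cache: Optional[Dict[str, torch.Tensor]] = None,
    only_on_master: bool = True,
) -> Dict[str, Any]:
    """All-gather each sharded state tensor across ``pg``; keep the (optionally
    pinned) CPU copy only on rank 0 of ``pg`` when ``only_on_master`` is set."""
    rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    out: Dict[str, Any] = {}
    for k, v in state.items():
        if isinstance(v, torch.Tensor) and k != "step":
            flat = torch.empty(v.numel() * world_size, dtype=v.dtype, device=v.device)
            # Every rank must join the collective, even if it drops the result.
            dist.all_gather_into_tensor(flat, v.flatten(), group=pg)
            if only_on_master and rank != 0:
                continue  # non-master ranks hold no CPU copy -> no duplicated host memory
            if pinned_cache is not None:
                if k not in pinned_cache:
                    # Pinned buffers are allocated once and reused across saves.
                    pinned_cache[k] = torch.empty_like(flat, device="cpu", pin_memory=True)
                pinned_cache[k].copy_(flat)
                out[k] = pinned_cache[k]
            else:
                out[k] = flat.cpu()
        elif not only_on_master or rank == 0:
            out[k] = v  # scalar entries such as "step" stay only on the master rank
    return out
```

Because the pinned buffers are allocated once per key and reused on later saves, repeated async checkpoints do not keep growing page-locked host memory.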


@@ -776,7 +776,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return {"state": packed_state, "param_groups": param_groups}
-    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
+    def state_dict(
+        self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, only_on_master: bool = False
+    ) -> Dict:
         """Return a state_dict same with DDP
         Returns:
@@ -785,23 +787,29 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            working_param = self.master_to_working_param[id(param)]
+            pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                zero_state[param] = copy.deepcopy(state)
+            else:
+                zero_state[param] = {}
             if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
-            zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    working_param = self.master_to_working_param[id(param)]
-                    pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param][k].copy_(param_state)
-                        zero_state[param][k] = pinned_state_dicts[param][k]
-                    else:
-                        zero_state[param][k] = param_state.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param][k].copy_(param_state)
+                            zero_state[param][k] = pinned_state_dicts[param][k]
+                        else:
+                            zero_state[param][k] = param_state.cpu()
         states_dict = self._pack_state(zero_state)
@@ -837,7 +845,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self.optim.load_state_dict(zero_state_dict)
     def state_dict_shard(
-        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+        self,
+        max_shard_size: int = 1024,
+        pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None,
+        only_on_master: bool = False,
     ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
         Only include the 'state' in state_dict.
@@ -862,25 +873,31 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 cnt += 1
         for param_idx, states in local_states.items():
             current_block_size = 0
-            current_block = copy.deepcopy(states)
             if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
+            if not only_on_master or get_nd_rank(pg) == 0:
+                current_block = copy.deepcopy(states)
+            else:
+                current_block = {}
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts is not None:
-                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                        current_block[k] = pinned_state_dicts[param_idx][k]
-                    else:
-                        current_block[k] = state_tensor.cpu()
+                    if not only_on_master or get_nd_rank(pg) == 0:
+                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                            pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                state_tensor, pin_memory=True, device="cpu"
+                            )
+                        if pinned_state_dicts is not None:
+                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                            current_block[k] = pinned_state_dicts[param_idx][k]
+                        else:
+                            current_block[k] = state_tensor.cpu()
                     current_block_size += calculate_tensor_size(state_tensor)
             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
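
For completeness, a hedged caller-side sketch of the new flag. Everything except the `state_dict` / `state_dict_shard` signatures is an assumption: `optimizer` stands in for an already-constructed `LowLevelZeroOptimizer`, `save_async` for whatever asynchronous writer (e.g. a tensornvme-backed one) consumes the pinned tensors, and global rank 0 is taken to be the master of every process group.

```python
# Hypothetical usage; only the state_dict/state_dict_shard signatures above are real.
import torch
import torch.distributed as dist


def save_async(obj, path):
    # Placeholder writer: a real async checkpoint path would hand the pinned
    # CPU tensors to an asynchronous backend instead of blocking on torch.save.
    torch.save(obj, path)


def save_optimizer(optimizer) -> None:
    # `optimizer` is assumed to be a LowLevelZeroOptimizer exposing the new API.
    pinned_state_dicts = {}  # reused across checkpoints so pinned buffers are allocated once

    # All ranks must call state_dict (they take part in the all-gathers), but with
    # only_on_master=True only the master rank of each parameter's process group
    # materializes the gathered, pinned CPU tensors.
    full_sd = optimizer.state_dict(pinned_state_dicts=pinned_state_dicts, only_on_master=True)
    if dist.get_rank() == 0:  # assuming global rank 0 is the master of every group
        save_async(full_sd, "optimizer.bin")

    # The sharded variant yields (shard, size) pairs and follows the same rule.
    shards = optimizer.state_dict_shard(
        max_shard_size=1024, pinned_state_dicts=pinned_state_dicts, only_on_master=True
    )
    for idx, (shard, _size) in enumerate(shards):
        if dist.get_rank() == 0:
            save_async(shard, f"optimizer_shard_{idx:05d}.bin")
```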