Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[async io] support async io (#6137)

* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
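The diff below threads an optional pinned_state_dicts cache through LowLevelZeroOptimizer.state_dict and state_dict_shard. A minimal caller-side sketch of the intended reuse pattern, assuming a long-lived optimizer object and a plain torch.save call in place of the plugin's asynchronous writer (save_optimizer, path, and the cache wiring are illustrative, not part of this commit):

    import torch

    # Hypothetical caller-side wiring: one long-lived dict of pinned CPU buffers is
    # threaded through every checkpoint call so the page-locked memory is reused
    # instead of being re-allocated on each save.
    pinned_state_dicts: dict = {}

    def save_optimizer(optimizer, path: str) -> None:
        # state_dict() gathers the sharded optimizer states; when pinned_state_dicts
        # is supplied, the gathered tensors are staged in pinned host buffers, which
        # is what lets an asynchronous writer overlap the disk write with training.
        state = optimizer.state_dict(pinned_state_dicts=pinned_state_dicts)
        torch.save(state, path)  # stand-in for the plugin's asynchronous writer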
@@ -770,7 +770,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return {"state": packed_state, "param_groups": param_groups}
 
-    def state_dict(self) -> Dict:
+    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
         """Return a state_dict same with DDP
 
         Returns:
@@ -779,15 +779,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            if pinned_state_dicts and param not in pinned_state_dicts:
+                pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
                     working_param = self.master_to_working_param[id(param)]
+                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
-                    zero_state[param][k] = param_state
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param][k].copy_(param_state)
+                        zero_state[param][k] = pinned_state_dicts[param][k]
+                    else:
+                        zero_state[param][k] = param_state.cpu()
 
         states_dict = self._pack_state(zero_state)
 
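The rewritten body above replaces the unconditional .cpu() copy with a copy_ into a preallocated pinned buffer. A self-contained sketch of why that matters for async IO, using plain PyTorch on an assumed CUDA device (offload_to_pinned and gpu_tensor are illustrative names, not from this commit):

    from typing import Optional

    import torch

    def offload_to_pinned(src: torch.Tensor, pinned: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Allocate the page-locked host buffer once and reuse it on later calls; only
        # pinned memory allows the device-to-host copy to run truly asynchronously.
        if pinned is None:
            pinned = torch.empty_like(src, device="cpu", pin_memory=True)
        pinned.copy_(src, non_blocking=True)  # queued on the current stream, returns immediately
        return pinned

    # The caller keeps the returned buffer alive across checkpoints and synchronizes
    # (stream or event) before handing the bytes to the file writer:
    #   buf = offload_to_pinned(gpu_tensor)
    #   torch.cuda.current_stream().synchronize()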
@@ -822,7 +830,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         self.optim.load_state_dict(zero_state_dict)
 
-    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+    def state_dict_shard(
+        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+    ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
         Only include the 'state' in state_dict.
 
@@ -847,18 +857,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
 
+            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+                pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(
+                            working_param, pin_memory=True, device="cpu"
+                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
+                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                        current_block[k] = pinned_state_dicts[param_idx][k]
+                    else:
+                        current_block[k] = state_tensor.cpu()
                     current_block_size += state_tensor.numel()
-                    current_block[k] = state_tensor
 
             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
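For the sharded path, state_dict_shard still yields (shard, size) pairs and now optionally fills the same caller-owned pinned cache. A hedged consumption sketch (save_optimizer_shards, the file naming, and the torch.save stand-in are assumptions, not part of this commit):

    import torch

    def save_optimizer_shards(optimizer, prefix: str, pinned_state_dicts: dict) -> None:
        # state_dict_shard() yields (shard_dict, shard_size) pairs; with pinned_state_dicts
        # supplied, the tensors inside each shard live in reusable pinned host buffers.
        shards = optimizer.state_dict_shard(max_shard_size=1024, pinned_state_dicts=pinned_state_dicts)
        for idx, (shard, _size) in enumerate(shards):
            torch.save(shard, f"{prefix}-{idx:05d}.bin")  # stand-in for the plugin's asynchronous writer

Because the pinned buffers are reused from one checkpoint to the next, the previous asynchronous write should be allowed to finish before the next call starts overwriting them.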