mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio]support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -903,6 +903,7 @@ class GeminiDDP(ModelWrapper):
         keep_vars: bool = False,
         max_shard_size: int = 1024,
         only_rank_0: bool = True,
+        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
 
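The new pinned_state_dicts argument lets the caller pass in a persistent dict of page-locked (pinned) CPU tensors that state_dict_shard fills and reuses across checkpoints; pinned host memory is the prerequisite for the asynchronous checkpoint writes this PR enables. A caller-side sketch of the intended flow (the save loop and file naming here are illustrative assumptions, not ColossalAI's checkpoint API):

    import torch

    pinned_buffers = {}  # reused across saves: pinned memory is allocated once

    def save_sharded(model, step, max_shard_size=1024):
        # state_dict_shard yields (shard_dict, shard_size) pairs; with
        # pinned_state_dicts set, each tensor in the shard is a pinned CPU copy.
        for idx, (shard, _size) in enumerate(
            model.state_dict_shard(max_shard_size=max_shard_size, pinned_state_dicts=pinned_buffers)
        ):
            torch.save(shard, f"ckpt-step{step}-shard{idx}.bin")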
@@ -943,6 +944,13 @@ class GeminiDDP(ModelWrapper):
                     gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
                 gathered_param = gathered_param_buffer.pop(param_to_save)
 
+            if pinned_state_dicts is not None:
+                if (prefix + name) not in pinned_state_dicts:
+                    pinned_state_dicts[prefix + name] = torch.empty_like(
+                        gathered_param, pin_memory=True, device="cpu"
+                    )
+                pinned_state_dicts[prefix + name].copy_(gathered_param)
+                gathered_param = pinned_state_dicts[prefix + name]
             block, block_size = sharder.append_param(prefix + name, gathered_param)
             if block is not None:
                 yield block, block_size
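The added block above is the pattern the rest of the diff repeats for buffers and extra state: lazily allocate a pinned CPU tensor the first time a key is seen, copy the gathered value into it, then rebind the local name so the yielded shard references the pinned copy instead of the original. A minimal standalone sketch of that pattern (the helper name is hypothetical):

    import torch

    def to_pinned(cache, key, src):
        # Allocate the page-locked buffer once per key; later checkpoints
        # reuse it, so no host allocation happens in the steady state.
        if key not in cache:
            cache[key] = torch.empty_like(src, pin_memory=True, device="cpu")
        cache[key].copy_(src)  # synchronous device-to-host copy, as in the diff
        return cache[key]      # caller now holds the pinned CPU copy

Pinned (page-locked) memory is what allows the CUDA driver to DMA directly between device and host buffer, and it is the precondition for overlapping the subsequent disk write with ongoing GPU work.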
@@ -954,6 +962,11 @@ class GeminiDDP(ModelWrapper):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
+                if pinned_state_dicts is not None:
+                    if (prefix + name) not in pinned_state_dicts:
+                        pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
+                    pinned_state_dicts[prefix + name].copy_(buffer)
+                    buffer = pinned_state_dicts[prefix + name]
                 block, block_size = sharder.append_param(prefix + name, buffer)
                 if block is not None:
                     yield block, block_size
@@ -964,6 +977,11 @@ class GeminiDDP(ModelWrapper):
             is not torch.nn.Module.get_extra_state
         ):
             extra_state = self.get_extra_state()
+            if pinned_state_dicts is not None:
+                if extra_state_key not in pinned_state_dicts:
+                    pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
+                pinned_state_dicts[extra_state_key].copy_(extra_state)
+                extra_state = pinned_state_dicts[extra_state_key]
             block, block_size = sharder.append_param(extra_state_key, extra_state)
             if block is not None:
                 yield block, block_size
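Because every tensor yielded by the three code paths above now lives in pinned CPU memory, serialization can be pushed off the training thread. An illustrative async-writer sketch under that assumption (this is not ColossalAI's actual writer, which lives elsewhere in the PR):

    import torch
    from concurrent.futures import ThreadPoolExecutor

    _writer = ThreadPoolExecutor(max_workers=1)  # one worker keeps shard order

    def save_shard_async(shard, path):
        # The shard's tensors are stable pinned CPU copies, so a background
        # thread can torch.save them while the GPU keeps training.
        return _writer.submit(torch.save, shard, path)

One caveat follows directly from the buffer reuse: the pinned tensors are overwritten in place on the next checkpoint, so a caller must wait on the pending futures before starting another save.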