[checkpointio] support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions

View File

@@ -903,6 +903,7 @@ class GeminiDDP(ModelWrapper):
keep_vars: bool = False,
max_shard_size: int = 1024,
only_rank_0: bool = True,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@@ -943,6 +944,13 @@ class GeminiDDP(ModelWrapper):
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
gathered_param = gathered_param_buffer.pop(param_to_save)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(
gathered_param, pin_memory=True, device="cpu"
)
pinned_state_dicts[prefix + name].copy_(gathered_param)
gathered_param = pinned_state_dicts[prefix + name]
block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
@@ -954,6 +962,11 @@ class GeminiDDP(ModelWrapper):
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
@@ -964,6 +977,11 @@ class GeminiDDP(ModelWrapper):
is not torch.nn.Module.get_extra_state
):
extra_state = self.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size