[checkpointio] support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions

View File

@@ -71,6 +71,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
for idx, d in states.items():
for k, v in d.items():
if v is None:
continue
nested_key = f"state{seperator}{idx}{seperator}{k}"
if not isinstance(v, torch.Tensor):
non_tensor_keys.append(nested_key)
@@ -87,7 +89,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
state_dict = {}
if metadata is not None:
if metadata is not None and "non_tensor_keys" in metadata:
non_tensor_keys = json.loads(metadata["non_tensor_keys"])
else:
non_tensor_keys = []
@@ -128,8 +131,10 @@ def prepare(
header = {}
offset = 0
header_metadata = {"format": "pt"}
if metadata is not None:
header["__metadata__"] = metadata
header_metadata.update(metadata)
header["__metadata__"] = header_metadata
for name, tensor in data.items():
n = tensor.numel() * tensor.element_size()
@@ -172,8 +177,9 @@ def move_and_save(
path: str,
state_dict: Dict[str, torch.Tensor],
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
prepared_data, _, tensor_keys = prepare(state_dict, metadata)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -188,9 +194,9 @@ def move_and_save(
return f_writer
def load_flat(checkpoint_path):
def load_flat(checkpoint_path, seperator: str = "."):
with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata()
state_dict_load = load_file(checkpoint_path)
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator)
return state_dict