Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[checkpointio] support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -71,6 +71,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
     for idx, d in states.items():
         for k, v in d.items():
+            if v is None:
+                continue
             nested_key = f"state{seperator}{idx}{seperator}{k}"
             if not isinstance(v, torch.Tensor):
                 non_tensor_keys.append(nested_key)
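For context, a minimal self-contained sketch (not the library code) of what this flattening loop produces under the same f"state{seperator}{idx}{seperator}{k}" key scheme; the state dict values below are made up.

# Minimal sketch, assuming the key scheme shown in the hunk above; "unused_slot" is hypothetical.
import torch

def flatten_states_sketch(states: dict, seperator: str = "."):
    flat_dict, non_tensor_keys = {}, []
    for idx, d in states.items():
        for k, v in d.items():
            if v is None:  # added by this patch: empty entries are skipped entirely
                continue
            nested_key = f"state{seperator}{idx}{seperator}{k}"
            if not isinstance(v, torch.Tensor):
                non_tensor_keys.append(nested_key)
            flat_dict[nested_key] = v
    return flat_dict, non_tensor_keys

states = {0: {"step": 10, "exp_avg": torch.zeros(4), "unused_slot": None}}
flat_dict, non_tensor_keys = flatten_states_sketch(states)
# flat_dict keys: "state.0.step", "state.0.exp_avg"; "state.0.unused_slot" never appears
# non_tensor_keys == ["state.0.step"], because 10 is not a tensor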
@@ -87,7 +89,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
 def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
     state_dict = {}
-    if metadata is not None:
+
+    if metadata is not None and "non_tensor_keys" in metadata:
         non_tensor_keys = json.loads(metadata["non_tensor_keys"])
     else:
         non_tensor_keys = []
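The tightened guard matters because checkpoint metadata can be present (for example carrying only a format tag) without a "non_tensor_keys" entry; checking membership before json.loads avoids a KeyError. A small sketch of the guard in isolation, with hypothetical metadata values:

# Sketch of the new guard; the metadata dicts below are illustrative.
import json
from typing import Optional

def read_non_tensor_keys(metadata: Optional[dict]) -> list:
    if metadata is not None and "non_tensor_keys" in metadata:
        return json.loads(metadata["non_tensor_keys"])
    return []

print(read_non_tensor_keys(None))                                     # []
print(read_non_tensor_keys({"format": "pt"}))                         # []
print(read_non_tensor_keys({"non_tensor_keys": '["state.0.step"]'}))  # ['state.0.step']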
@@ -128,8 +131,10 @@ def prepare(
     header = {}
     offset = 0
 
+    header_metadata = {"format": "pt"}
     if metadata is not None:
-        header["__metadata__"] = metadata
+        header_metadata.update(metadata)
+    header["__metadata__"] = header_metadata
 
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
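With this change, caller-supplied metadata is merged into a default {"format": "pt"} entry rather than replacing "__metadata__" wholesale, so the format tag survives. A small sketch of just the merge, using a hypothetical "non_tensor_keys" value:

# Sketch of the new header merge; the metadata value passed in is hypothetical.
def build_header_metadata(metadata=None) -> dict:
    header = {}
    header_metadata = {"format": "pt"}
    if metadata is not None:
        header_metadata.update(metadata)  # user metadata is merged, not substituted
    header["__metadata__"] = header_metadata
    return header

print(build_header_metadata())
# {'__metadata__': {'format': 'pt'}}
print(build_header_metadata({"non_tensor_keys": '["state.0.step"]'}))
# {'__metadata__': {'format': 'pt', 'non_tensor_keys': '["state.0.step"]'}}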
@@ -172,8 +177,9 @@ def move_and_save(
     path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> None:
-    prepared_data, _, tensor_keys = prepare(state_dict)
+    prepared_data, _, tensor_keys = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
     f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
     f_writer.write(n.to_bytes(8, byteorder="little"))
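The write order visible here is an 8-byte little-endian header-length prefix, the JSON header (now carrying the merged metadata), then the tensor bytes, streamed through AsyncFileWriter. Below is a simplified synchronous stand-in for that byte layout only; it is a sketch, not the library's writer, and it skips safetensors dtype codes, so its output is not guaranteed to be readable by safe_open/load_file.

# Simplified synchronous sketch of the layout: length prefix, JSON header, tensor bytes.
import json
import torch

def save_flat_sync_sketch(path: str, tensors: dict, metadata: dict = None) -> None:
    header_metadata = {"format": "pt"}
    if metadata is not None:
        header_metadata.update(metadata)
    header = {"__metadata__": header_metadata}
    offset = 0
    for name, t in tensors.items():
        n = t.numel() * t.element_size()
        header[name] = {"dtype": str(t.dtype), "shape": list(t.shape),
                        "data_offsets": [offset, offset + n]}
        offset += n
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(len(header_bytes).to_bytes(8, byteorder="little"))  # same 8-byte prefix as above
        f.write(header_bytes)
        for t in tensors.values():
            f.write(t.contiguous().numpy().tobytes())

save_flat_sync_sketch("/tmp/sketch.bin", {"w": torch.zeros(2, 2)},
                      metadata={"non_tensor_keys": "[]"})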
@@ -188,9 +194,9 @@ def move_and_save(
     return f_writer
 
 
-def load_flat(checkpoint_path):
+def load_flat(checkpoint_path, seperator: str = "."):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator)
     return state_dict
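A hedged usage sketch of the extended load_flat; the module path and the checkpoint file name are assumptions (the diff does not show the file being modified), and the seperator argument only needs to be passed when the checkpoint was flattened with a non-default separator.

# Usage sketch; module path and "optimizer.safetensors" are assumptions.
from colossalai.utils.safetensors import load_flat

optim_state = load_flat("optimizer.safetensors")                      # default "." seperator
optim_state_alt = load_flat("optimizer.safetensors", seperator="#")   # only if flattened with "#"
# load_flat reads the file's metadata via safe_open, then rebuilds the nested optimizer
# state dict from the flat "state{sep}{idx}{sep}{key}" entries.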