[async io] support async io (#6137)
* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
@@ -1,10 +1,11 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
+import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
-from safetensors.torch import _TYPES
+from safetensors.torch import _TYPES, load_file, safe_open

 try:
     from tensornvme.async_file_io import AsyncFileWriter
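The tensornvme import sits in a try block because it is an optional dependency. The except clause falls outside this hunk, so the fallback below is only an assumed sketch of the guard pattern, not the code from this commit:

    try:
        from tensornvme.async_file_io import AsyncFileWriter
    except ModuleNotFoundError:
        # Assumed fallback: without tensornvme, async checkpoint I/O is unavailable.
        AsyncFileWriter = None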
@@ -27,34 +28,93 @@ class PreparedData:
     offset: int


-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+def flatten_dict(nested_dict, parent_key="", separator="^"):
+    """
+    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
+
+    nested_dict: The input nested dictionary.
+    parent_key: The parent key currently being processed.
+    separator: The separator used to join keys, default is '^', but it can be customized to another symbol.
+    :return: A flattened dictionary.
+    """
+    items = []
+    for k, v in nested_dict.items():
+        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, separator).items())
+        else:
+            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
+            items.append((new_key, v))
+
+    return dict(items)
+
+
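As a quick illustration, here is what flatten_dict produces for an optimizer-style nested dict (a hypothetical input; note that non-tensor leaves are converted to float16 tensors):

    import torch

    nested = {"state": {0: {"step": 5, "exp_avg": torch.zeros(2)}}}
    flat = flatten_dict(nested)
    # Keys are joined with '^'; the int leaf 5 becomes tensor(5., dtype=torch.float16):
    # {'state^0^step': tensor(5., dtype=torch.float16), 'state^0^exp_avg': tensor([0., 0.])}
    print(list(flat.keys()))  # ['state^0^step', 'state^0^exp_avg']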
+def unflatten_dict(flattened_dict, separator="^"):
+    """
+    Restore a flattened dictionary back to a multi-level nested dictionary.
+
+    flattened_dict: The flattened dictionary.
+    separator: The separator used during flattening, default is '^', but it can be customized to another symbol.
+    :return: The restored nested dictionary.
+    """
+    nested_dict = {}
+    for key, value in flattened_dict.items():
+        keys = key.split(separator)
+        try:
+            keys[0] = int(keys[0])
+        except ValueError:
+            warnings.warn(f"{keys[0]} can't be converted to an integer")
+        d = nested_dict
+        for part in keys[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        assert isinstance(value, torch.Tensor)
+        d[keys[-1]] = value
+
+    return nested_dict
+
+
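For example, re-nesting a flattened per-parameter state (hypothetical input; note that only the leading key segment is converted back to an int, where possible):

    import torch

    flat = {"0^step": torch.tensor(5.0), "0^exp_avg": torch.zeros(2)}
    restored = unflatten_dict(flat)
    # The leading '0' becomes the integer key 0 again:
    # {0: {'step': tensor(5.), 'exp_avg': tensor([0., 0.])}}
    print(restored[0]["step"])  # tensor(5.)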
+def prepare(
+    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+    if metadata is not None:
+        assert isinstance(metadata, dict)
+        for k, v in metadata.items():
+            metadata[k] = json.dumps(v)
+            assert isinstance(k, str)
+            assert isinstance(metadata[k], str)
+
     tensors = []
     tensor_keys = []
-    metadata = {}
+    header = {}
     offset = 0

+    if metadata is not None:
+        header["__metadata__"] = metadata
+
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
         )
         offset += n
-        metadata[name] = asdict(tensor_info)
+        header[name] = asdict(tensor_info)
         tensors.append(tensor)
         tensor_keys.append(name)

-    metadata_buf = json.dumps(metadata).encode("utf-8")
+    header_buf = json.dumps(header).encode("utf-8")

-    extra = (8 - len(metadata_buf) % 8) % 8
-    metadata_buf += b" " * extra
+    extra = (8 - len(header_buf) % 8) % 8
+    header_buf += b" " * extra

-    n = len(metadata_buf)
+    n = len(header_buf)

-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
+    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
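A small sketch of what prepare() returns, assuming the functions above are importable. The header follows the safetensors format: a JSON blob padded with spaces to an 8-byte boundary, with user metadata stored under the "__metadata__" key as JSON-encoded strings:

    import json
    import torch

    data = {"w": torch.zeros(2, 2, dtype=torch.float32)}
    prepared, tensors, keys = prepare(data, metadata={"param_groups": [{"lr": 0.001}]})
    header = json.loads(prepared.header_bytes)
    print(header["w"])             # {'dtype': 'F32', 'shape': [2, 2], 'data_offsets': [0, 16]}
    print(header["__metadata__"])  # {'param_groups': '[{"lr": 0.001}]'}
    print(prepared.n)              # length of the padded header that save() writes after the 8-byte size field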


-def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors, _ = prepare(state_dict)
+def save(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

     f_writer.write(n.to_bytes(8, byteorder="little"))
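The byte layout written here is standard safetensors framing: an unsigned 64-bit little-endian header size, then the JSON header, then the raw tensor bytes. A sketch of parsing it back by hand (the file name is hypothetical):

    import json

    with open("model.safetensors", "rb") as f:
        header_len = int.from_bytes(f.read(8), byteorder="little")
        header = json.loads(f.read(header_len))
        # Each entry's data_offsets are relative to the end of the header.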
@@ -64,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


+def save_nested(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    flatten_data = flatten_dict(state_dict)
+    save(f_writer, flatten_data, metadata)
+
+
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
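A usage sketch for the async nested save. The AsyncFileWriter constructor arguments and the synchronize() call are assumptions about the tensornvme API, not confirmed by this diff:

    import torch
    from tensornvme.async_file_io import AsyncFileWriter

    optimizer_state = {0: {"step": torch.tensor(10.0), "exp_avg": torch.zeros(4)}}
    f_writer = AsyncFileWriter("optim.safetensors", n_entries=8, backend="pthread")  # assumed signature
    save_nested(f_writer, optimizer_state, metadata={"param_groups": []})
    f_writer.synchronize()  # assumed: block until pending async writes finish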
@@ -81,3 +148,16 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+
+
+def load_flat(checkpoint_path):
+    with safe_open(checkpoint_path, framework="pt") as f:
+        metadata = f.metadata()
+    state_dict_load = load_file(checkpoint_path)
+    state_dict = unflatten_dict(state_dict_load)
+    if metadata is None:
+        return state_dict
+    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
+    combined_state_dict = {"state": state_dict}
+    combined_state_dict.update(metadata)
+    return combined_state_dict
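Reading it back, load_flat re-nests the flattened keys and re-attaches the JSON-decoded metadata alongside the restored state (continuing the hypothetical file from the save sketch above):

    loaded = load_flat("optim.safetensors")
    print(loaded.keys())               # dict_keys(['state', 'param_groups'])
    print(loaded["state"][0]["step"])  # tensor(10.)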