[checkpointio] support async model save (#6131)

* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-11-14 11:38:10 +08:00
parent 5a03d2696d
commit d4a436051d
7 changed files with 209 additions and 28 deletions

View File

@@ -1,7 +1,7 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from safetensors.torch import _TYPES
@@ -27,10 +27,11 @@ class PreparedData:
offset: int
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = []
tensor_keys = []
metadata = {}
offset = 0
@@ -42,6 +43,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
offset += n
metadata[name] = asdict(tensor_info)
tensors.append(tensor)
tensor_keys.append(name)
metadata_buf = json.dumps(metadata).encode("utf-8")
@@ -50,11 +52,11 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
n = len(metadata_buf)
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
prepared_data, tensors = prepare(state_dict)
prepared_data, tensors, _ = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -62,3 +64,22 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
for tensor in tensors:
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
def move_and_save(
f_writer: AsyncFileWriter,
state_dict: Dict[str, torch.Tensor],
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
f_writer.write(n.to_bytes(8, byteorder="little"))
f_writer.write(header_bytes)
f_writer.register_h2d(len(tensor_keys))
for name in tensor_keys:
if state_dict_pinned:
f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
else:
f_writer.write_tensor(state_dict[name])