Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
[checkpointio] support async model save (#6131)
* [checkpointio] support async model save
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -1,7 +1,7 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
 from dataclasses import asdict, dataclass
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from safetensors.torch import _TYPES
@@ -27,10 +27,11 @@ class PreparedData:
     offset: int
 
 
-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
+def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
     sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
 
     tensors = []
+    tensor_keys = []
     metadata = {}
     offset = 0
 
@@ -42,6 +43,7 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
         offset += n
         metadata[name] = asdict(tensor_info)
         tensors.append(tensor)
+        tensor_keys.append(name)
 
     metadata_buf = json.dumps(metadata).encode("utf-8")
 
@@ -50,11 +52,11 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
 
     n = len(metadata_buf)
 
-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
+    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
 
 
 def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors = prepare(state_dict)
+    prepared_data, tensors, _ = prepare(state_dict)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
 
     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -62,3 +64,22 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
 
     for tensor in tensors:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
+
+
+def move_and_save(
+    f_writer: AsyncFileWriter,
+    state_dict: Dict[str, torch.Tensor],
+    state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
+) -> None:
+    prepared_data, _, tensor_keys = prepare(state_dict)
+    n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+
+    f_writer.write(n.to_bytes(8, byteorder="little"))
+    f_writer.write(header_bytes)
+
+    f_writer.register_h2d(len(tensor_keys))
+    for name in tensor_keys:
+        if state_dict_pinned:
+            f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
+        else:
+            f_writer.write_tensor(state_dict[name])
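For context, the layout that save() streams through AsyncFileWriter is the safetensors layout referenced at the top of the file: an 8-byte little-endian header length, the JSON header built by prepare(), then each tensor's raw bytes in prepared order. Below is a minimal synchronous sketch of that same byte layout using plain file I/O; the helper name sync_save is illustrative and not part of this commit, and the sketch assumes the prepare()/PreparedData definitions from this module.

from typing import Dict

import torch


def sync_save(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
    # Hypothetical helper, for illustration only: writes the same byte layout
    # as save(), but synchronously with ordinary file I/O.
    prepared_data, tensors, _ = prepare(state_dict)  # prepare() as defined above
    with open(path, "wb") as f:
        f.write(prepared_data.n.to_bytes(8, byteorder="little"))  # header length
        f.write(prepared_data.header_bytes)  # JSON header
        for tensor in tensors:
            # write_raw() streams tensor memory directly; here the same bytes
            # are copied out through a contiguous CPU uint8 view instead
            raw = tensor.detach().cpu().contiguous().flatten().view(torch.uint8)
            f.write(raw.numpy().tobytes())

Since the module is described as a safetensors serializer, a file written this way should also be readable with safetensors.torch.load_file, though that is an expectation rather than something this diff asserts.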
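The new move_and_save() optionally takes state_dict_pinned, a dict of page-locked CPU tensors keyed like the state dict, presumably so device tensors can be staged through reusable pinned buffers while the write proceeds asynchronously. The diff does not show how that dict is built; the sketch below is one plausible way to allocate it, with the helper name make_pinned_buffers being a hypothetical label rather than anything introduced by this commit.

from typing import Dict

import torch


def make_pinned_buffers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Hypothetical helper: allocate a pinned (page-locked) CPU buffer matching
    # each tensor's shape and dtype, keyed by the same names as the state dict.
    return {
        name: torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
        for name, tensor in state_dict.items()
    }

The resulting dict can then be passed as state_dict_pinned, which would let repeated checkpoints reuse the same pinned memory instead of allocating it on every save.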