diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
new file mode 100644
index 000000000..4e295cdfc
--- /dev/null
+++ b/colossalai/utils/safetensors.py
@@ -0,0 +1,49 @@
+# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
+import json
+from dataclasses import asdict, dataclass
+from typing import Dict, List, Tuple
+
+import torch
+from safetensors.torch import _TYPES
+
+_TYPES_INV = {v: k for k, v in _TYPES.items()}
+
+
+@dataclass
+class TensorInfo:
+    dtype: str
+    shape: List[int]
+    data_offsets: Tuple[int, int]
+
+
+@dataclass
+class PreparedData:
+    n: int
+    header_bytes: bytes
+    offset: int
+
+
+def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
+    sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
+
+    tensors = []
+    metadata = {}
+    offset = 0
+
+    for name, tensor in sorted_data:
+        n = tensor.numel() * tensor.element_size()
+        tensor_info = TensorInfo(
+            dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
+        )
+        offset += n
+        metadata[name] = asdict(tensor_info)
+        tensors.append(tensor)
+
+    metadata_buf = json.dumps(metadata).encode("utf-8")
+
+    extra = (8 - len(metadata_buf) % 8) % 8
+    metadata_buf += b" " * extra
+
+    n = len(metadata_buf)
+
+    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
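
For context, here is a minimal sketch (not part of this diff) of how the output of `prepare` maps onto the safetensors on-disk layout: an 8-byte little-endian header length, the space-padded JSON header, then each tensor's raw bytes at the `data_offsets` recorded in that header. The writer name `save_flat` and the byte-serialization path are assumptions for illustration; the round-trip check assumes the header produced by `prepare` is accepted by the reference `safetensors` loader.

```python
# Illustrative sketch only: a plain, blocking writer consuming the output of
# `prepare` from the new module. `save_flat` is a hypothetical name.
import struct
from typing import Dict

import torch

from colossalai.utils.safetensors import prepare


def save_flat(path: str, data: Dict[str, torch.Tensor]) -> None:
    prepared, tensors = prepare(data)
    with open(path, "wb") as f:
        # safetensors layout: u64 little-endian header size, JSON header
        # (already padded to a multiple of 8 by `prepare`), then the tensor
        # bytes in the same order as the offsets recorded in the header.
        f.write(struct.pack("<Q", prepared.n))
        f.write(prepared.header_bytes)
        for tensor in tensors:
            # Reinterpret as uint8 so dtypes without a NumPy counterpart
            # (e.g. bfloat16) can still be dumped byte-for-byte.
            flat = tensor.detach().cpu().contiguous().view(-1).view(torch.uint8)
            f.write(flat.numpy().tobytes())


if __name__ == "__main__":
    state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
    save_flat("demo.safetensors", state)

    # Round-trip check against the reference reader.
    from safetensors.torch import load_file

    loaded = load_file("demo.safetensors")
    assert torch.equal(loaded["weight"], state["weight"])
```

Two design points carried over from the referenced Rust serializer: tensors are sorted so that same-dtype entries are grouped together, and the JSON header is padded with spaces to a multiple of 8 bytes, presumably so the tensor data region starts on an 8-byte boundary.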