[ckpt] Add async ckpt api (#6136)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
This commit is contained in:
Wang Binluo
2024-11-15 18:19:16 +08:00
committed by Hongxin Liu
parent d4a436051d
commit 8e08c27e19
12 changed files with 174 additions and 86 deletions

View File

@@ -28,14 +28,12 @@ class PreparedData:
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = []
tensor_keys = []
metadata = {}
offset = 0
for name, tensor in sorted_data:
for name, tensor in data.items():
n = tensor.numel() * tensor.element_size()
tensor_info = TensorInfo(
dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)