[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test

Author: Hongxin Liu
Date: 2024-12-02 11:29:19 +08:00
Committed by: GitHub
Parent: ab856fd308
Commit: 6280cb18b8

9 changed files with 33 additions and 54 deletions


@@ -15,6 +15,8 @@ import io
 from torch.distributed.distributed_c10d import _pickler, _unpickler
 
+ASYNC_WRITE_ENTRIES = 32
+
 
 def _object_to_tensor(obj, device):
     f = io.BytesIO()
@@ -149,32 +151,31 @@ def prepare(
     return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
 
 
-def save(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
+def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
     prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
 
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)
 
     for tensor in tensors:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
+    return f_writer
 
 
-def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
     flatten_data, metadata = _flatten_optim_state_dict(state_dict)
-    save(f_writer, flatten_data, metadata)
+    return save(path, flatten_data, metadata)
 
 
 def move_and_save(
-    f_writer: AsyncFileWriter,
+    path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     prepared_data, _, tensor_keys = prepare(state_dict)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
 
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)
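
For context, a rough before/after sketch of the caller-side change introduced by this hunk; the placeholder names and the old constructor arguments are assumptions, not shown in the diff:

# Before (sketch): the caller created and owned the AsyncFileWriter.
#   f_writer = AsyncFileWriter(path, ...)   # old constructor args not shown here
#   save(f_writer, state_dict, metadata)
#
# After (sketch): the caller passes a path; save() builds the writer with
# ASYNC_WRITE_ENTRIES queue entries and n_tasks = 2 + len(tensors), i.e. one
# task for the 8-byte header length, one for the header bytes, and one
# write_raw per tensor, then returns the writer so the caller can await
# completion of the asynchronous writes.
#   f_writer = save(path, state_dict, metadata)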
@@ -184,6 +185,7 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+    return f_writer
 
 
 def load_flat(checkpoint_path):
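
A minimal end-to-end usage sketch of the refactored API. The import path below is an assumption (the file name is not shown in this extract), and synchronizing/closing the returned writer is left as a comment because that part of the writer's API is not shown in the diff:

import torch

# Assumed module path; not shown in the diff above.
from colossalai.utils.safetensors import move_and_save, save

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# New API: pass a path instead of a pre-built writer; the AsyncFileWriter is
# constructed inside save() and returned to the caller.
writer = save("model.safetensors", state_dict)

# move_and_save() optionally reuses pre-pinned CPU staging buffers keyed by
# tensor name while writing asynchronously (as suggested by its signature).
pinned = {k: torch.empty_like(v, pin_memory=True) for k, v in state_dict.items()}
writer2 = move_and_save("model_moved.safetensors", state_dict, state_dict_pinned=pinned)

# Keep both writers alive (and synchronize/close them via the writer's own API,
# not shown here) so the background writes can complete before the process exits.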