[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
Hongxin Liu, 2024-12-02 11:29:19 +08:00 (committed by GitHub)
commit 6280cb18b8, parent ab856fd308
9 changed files with 33 additions and 54 deletions


@@ -273,7 +273,6 @@ def async_save_state_dict_shards(
     base_filename: str,
     is_master: bool,
     pinned_state_dict: Optional[Dict[str, torch.Tensor]],
-    n_write_entries: int,
     use_pp_format: bool = False,
 ) -> Tuple[int, Dict[str, torch.Tensor], list]:
     """
@@ -290,7 +289,6 @@ def async_save_state_dict_shards(
     Returns:
         int: the total size of shards
     """
-    from tensornvme.async_file_io import AsyncFileWriter
 
     total_size = 0
     shard_filenames = []
@@ -311,9 +309,6 @@ def async_save_state_dict_shards(
             index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)
-        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
-        writers.append(writer)
-
         if pinned_state_dict is not None:
             sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
         else:
@@ -321,7 +316,8 @@ def async_save_state_dict_shards(
             returned_state_dict.update(sub_pinned_state_dict)
 
         # Only save on master rank.
-        move_and_save(writer, shard, sub_pinned_state_dict)
+        writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
+        writers.append(writer)
         shard_filenames.append(shard_file)
         del shard
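
The hunks above move writer construction out of async_save_state_dict_shards: the n_write_entries parameter and the per-shard AsyncFileWriter setup are gone, and move_and_save now opens the writer itself and returns it so the caller can keep it alive until the asynchronous write completes. A minimal sketch of the refactored helper, with the signature inferred from the call site above; N_WRITE_ENTRIES and write_tensor are assumptions for illustration, not the library's actual names (the real code serializes shards in safetensors format):

# Sketch only; signature inferred from the new call site in the diff.
from typing import Dict, Optional

import torch
from tensornvme.async_file_io import AsyncFileWriter

N_WRITE_ENTRIES = 32  # assumed module-level default replacing the old per-call argument


def move_and_save(
    path: str,
    state_dict: Dict[str, torch.Tensor],
    state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> AsyncFileWriter:
    writer = AsyncFileWriter(path, N_WRITE_ENTRIES, backend="pthread")
    for name, tensor in state_dict.items():
        if state_dict_pinned is not None:
            # Stage the tensor in its pre-pinned CPU buffer so the background
            # thread can write it without blocking the training loop.
            state_dict_pinned[name].copy_(tensor)
            tensor = state_dict_pinned[name]
        writer.write_tensor(tensor)  # stand-in for the real serialization call
    # Return the writer so the caller keeps a reference until the write is synced.
    return writer

On the caller side, as the final hunk shows, this becomes a collect-then-sync pattern: each shard's writer is appended to the writers list, presumably so all pending writes can be synchronized before the checkpoint is finalized.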