[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
This commit is contained in:
Hongxin Liu
2024-12-02 11:29:19 +08:00
committed by GitHub
parent ab856fd308
commit 6280cb18b8
9 changed files with 33 additions and 54 deletions

View File

@@ -54,13 +54,11 @@ class GeneralCheckpointIO(CheckpointIO):
pass
if use_async:
-from tensornvme.async_file_io import AsyncFileWriter
-writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.async_writers.append(writer)
-move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
# save the checkpoint
@@ -196,7 +194,6 @@ class GeneralCheckpointIO(CheckpointIO):
base_filename=weights_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
-n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)