[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
This commit is contained in:
Hongxin Liu
2024-12-02 11:29:19 +08:00
committed by GitHub
parent ab856fd308
commit 6280cb18b8
9 changed files with 33 additions and 54 deletions

View File

@@ -686,15 +686,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import move_and_save
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)