mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-20 09:01:06 +00:00
[checkpointio] support debug log (#6153)
* [checkpointio] support debug log
* [checkpointio] refactor async writer api
* fix test
* fix test
This commit is contained in:
@@ -54,13 +54,11 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
pass
|
||||
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||
self.async_writers.append(writer)
|
||||
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
|
||||
|
||||
else:
|
||||
# save the checkpoint
|
||||
@@ -196,7 +194,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
base_filename=weights_name,
|
||||
is_master=True,
|
||||
pinned_state_dict=pinned_state_dict,
|
||||
n_write_entries=self.N_WRITE_ENTRIES,
|
||||
)
|
||||
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
|
||||
self.async_writers.extend(writers)
|
||||
|
Reference in New Issue
Block a user