Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00.

Commit: [checkpointio] support debug log (#6153)

* [checkpointio] support debug log
* [checkpointio] refactor async writer api
* fix test
* fix test

This commit is contained in:
@@ -137,12 +137,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
-                from tensornvme.async_file_io import AsyncFileWriter
-
                 from colossalai.utils.safetensors import save_nested

-                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-                save_nested(f_writer, state_dict)
+                f_writer = save_nested(checkpoint, state_dict)
                 self.async_writers.append(f_writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -222,16 +220,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                 checkpoint_file_path = os.path.join(checkpoint, shard_file)
                 if self.coordinator.is_master():
                     if use_async:
-                        from tensornvme.async_file_io import AsyncFileWriter
-
                         from colossalai.utils.safetensors import save_nested

-                        f_writer = AsyncFileWriter(
-                            checkpoint_file_path,
-                            n_entries=self.N_WRITE_ENTRIES,
-                            backend="pthread",
-                        )
-                        save_nested(f_writer, shard)
+                        f_writer = save_nested(checkpoint_file_path, shard)
                         self.async_writers.append(f_writer)
                     else:
                         save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
Reference in New Issue
Block a user