[checkpointio] support debug log (#6153)

* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
Author: Hongxin Liu
Date: 2024-12-02 11:29:19 +08:00
Committed by: GitHub
Parent: ab856fd308
Commit: 6280cb18b8

9 changed files with 33 additions and 54 deletions

@@ -137,12 +137,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
 if self.coordinator.is_master():
     if use_async:
-        from tensornvme.async_file_io import AsyncFileWriter
         from colossalai.utils.safetensors import save_nested
-        f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-        save_nested(f_writer, state_dict)
+        f_writer = save_nested(checkpoint, state_dict)
         self.async_writers.append(f_writer)
     else:
         save_state_dict(state_dict, checkpoint, use_safetensors=False)
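
The refactor visible in both hunks: callers no longer construct an AsyncFileWriter themselves and pass it into save_nested; instead, save_nested takes the target path, builds the writer internally, and returns it. The new helper's body is not shown in this diff, so the following is only a plausible sketch. _N_WRITE_ENTRIES stands in for the class constant the removed lines referenced, and the serialization step is an assumption.

# Hedged sketch of the refactored helper (its body is not part of this diff);
# the only contract the hunks guarantee is the call shape:
#     f_writer = save_nested(path, state_dict)
from safetensors.torch import save as _to_bytes  # serializes a flat dict of tensors
from tensornvme.async_file_io import AsyncFileWriter

_N_WRITE_ENTRIES = 32  # stand-in for self.N_WRITE_ENTRIES, which is not visible here


def save_nested(checkpoint_path: str, state_dict: dict) -> AsyncFileWriter:
    """Open an async writer on the target path, queue the serialized state
    dict, and return the writer so the caller can track completion."""
    f_writer = AsyncFileWriter(checkpoint_path, n_entries=_N_WRITE_ENTRIES, backend="pthread")
    # The real helper flattens the nested optimizer state first; this sketch
    # assumes `state_dict` already maps str -> torch.Tensor.
    f_writer.write(_to_bytes(state_dict))
    return f_writer
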
@@ -222,16 +220,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 checkpoint_file_path = os.path.join(checkpoint, shard_file)
 if self.coordinator.is_master():
     if use_async:
-        from tensornvme.async_file_io import AsyncFileWriter
         from colossalai.utils.safetensors import save_nested
-        f_writer = AsyncFileWriter(
-            checkpoint_file_path,
-            n_entries=self.N_WRITE_ENTRIES,
-            backend="pthread",
-        )
-        save_nested(f_writer, shard)
+        f_writer = save_nested(checkpoint_file_path, shard)
         self.async_writers.append(f_writer)
     else:
         save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
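
Both call sites append the returned writer to self.async_writers, so writer lifetime is now managed in one place rather than by each save path. A drain step presumably synchronizes and discards the writers once the checkpoint must be durable; the method name below is hypothetical, since only self.async_writers itself appears in these hunks.

# Illustrative lifecycle sketch: `_sync_io` is a hypothetical name, and the
# synchronize-then-clear pattern is an assumption, not part of this diff.
def _sync_io(self) -> None:
    for writer in self.async_writers:
        writer.synchronize()  # block until this writer's queued writes have landed
    self.async_writers.clear()
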