[checkpointio] disable buffering

This commit is contained in:
ver217
2024-11-21 14:33:26 +08:00
parent cf519dac6a
commit 8fddbab04c
4 changed files with 11 additions and 5 deletions

View File

@@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
from colossalai.utils.safetensors import move_and_save
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
)
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)