[checkpointio] disable buffering

This commit is contained in:
ver217
2024-11-21 14:33:26 +08:00
parent cf519dac6a
commit 8fddbab04c
4 changed files with 11 additions and 5 deletions

View File

@@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested
-f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+f_writer = AsyncFileWriter(
+fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+)
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
@@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
from colossalai.utils.safetensors import save_nested
f_writer = AsyncFileWriter(
-fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+fp=open(checkpoint_file_path, "wb", buffering=0),
+n_entries=self.N_WRITE_ENTRIES,
+backend="pthread",
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)