mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 11:31:58 +00:00
Merge pull request #6149 from ver217/hotfix/ckpt
[checkpointio] disable buffering
This commit is contained in:
commit
8ecff0cb7f
@ -141,7 +141,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
||||
from colossalai.utils.safetensors import save_nested
|
||||
|
||||
f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
|
||||
f_writer = AsyncFileWriter(
|
||||
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
|
||||
)
|
||||
save_nested(f_writer, state_dict)
|
||||
self.async_writers.append(f_writer)
|
||||
else:
|
||||
@ -225,7 +227,9 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
from colossalai.utils.safetensors import save_nested
|
||||
|
||||
f_writer = AsyncFileWriter(
|
||||
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
|
||||
fp=open(checkpoint_file_path, "wb", buffering=0),
|
||||
n_entries=self.N_WRITE_ENTRIES,
|
||||
backend="pthread",
|
||||
)
|
||||
save_nested(f_writer, shard)
|
||||
self.async_writers.append(f_writer)
|
||||
|
@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
self.async_writers.append(writer)
|
||||
|
@ -690,7 +690,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
writer = AsyncFileWriter(
|
||||
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
|
||||
)
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
self.async_writers.append(writer)
|
||||
|
@ -311,7 +311,7 @@ def async_save_state_dict_shards(
|
||||
index_file.append_weight_map(key, shard_file)
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
|
||||
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
|
||||
writers.append(writer)
|
||||
|
||||
if pinned_state_dict is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user