mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[checkpointio] fix zero optimizer async save memory (#6151)
* [checkpointio] fix zero optimizer async save memory
* [checkpointio] fit new tensornvme api
* [checkpointio] fit new tensornvme api
This commit is contained in:
@@ -72,7 +72,6 @@ class CheckpointIO(ABC):
|
||||
def _sync_io(self):
|
||||
for writer in self.async_writers:
|
||||
writer.synchronize()
|
||||
writer.fp.close()
|
||||
self.async_writers.clear()
|
||||
|
||||
def _sync_d2h(self):
|
||||
|
@@ -56,7 +56,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
if use_async:
|
||||
from tensornvme.async_file_io import AsyncFileWriter
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
|
||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
self.async_writers.append(writer)
|
||||
|
@@ -690,9 +690,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
|
||||
writer = AsyncFileWriter(
|
||||
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
|
||||
)
|
||||
writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
self.async_writers.append(writer)
|
||||
|
@@ -311,7 +311,7 @@ def async_save_state_dict_shards(
|
||||
index_file.append_weight_map(key, shard_file)
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
|
||||
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
|
||||
writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
|
||||
writers.append(writer)
|
||||
|
||||
if pinned_state_dict is not None:
|
||||
|
Reference in New Issue
Block a user