Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[checkpointio] fix zero optimizer async save memory (#6151)
* [checkpointio] fix zero optimizer async save memory
* [checkpointio] fit new tensornvme api
* [checkpointio] fit new tensornvme api
@@ -128,22 +128,20 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        state_dict = optimizer.state_dict(pinned_state_dicts)
+        state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
                 from tensornvme.async_file_io import AsyncFileWriter

                 from colossalai.utils.safetensors import save_nested

-                f_writer = AsyncFileWriter(
-                    fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
-                )
+                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
                 save_nested(f_writer, state_dict)
                 self.async_writers.append(f_writer)
             else:
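This hunk is the memory fix itself: pinned host buffers are now allocated only on the master rank, while `only_on_master=True` lets every rank still enter the collective `state_dict()` call. A minimal, self-contained sketch of why the cached pinned buffer matters; the names are illustrative, not ColossalAI's implementation, and `pin_memory=True` assumes a CUDA-enabled PyTorch build:

import torch

# Cache of pinned buffers, keyed like the real code: one entry per optimizer.
_pinned_cache = {}

def copy_to_pinned(key, tensor):
    buf = _pinned_cache.get(key)
    if buf is None or buf.shape != tensor.shape or buf.dtype != tensor.dtype:
        # Page-locked host memory is never swappable, which is why allocating
        # it on every rank (instead of just the master, the only rank that
        # writes) inflated host memory at scale.
        buf = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True)
        _pinned_cache[key] = buf
    buf.copy_(tensor, non_blocking=True)  # device-to-host copy can overlap compute
    return buf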
@@ -192,13 +190,15 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        if use_async:
+        if use_async and self.coordinator.is_master():
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
         else:
             pinned_state_dicts = None
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)
+        sharded_state = optimizer.state_dict_shard(
+            max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
+        )

         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
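The second hunk applies the same gating to the sharded save path. The reason `state_dict_shard()` must still run on every rank is the collective communication noted in the comments of the first hunk: if non-master ranks skipped the call, the collectives would not match and the job would hang. A hypothetical `gather_full_state` helper (not ColossalAI's code) sketches that hazard under the assumption the optimizer gathers its shards internally:

import torch
import torch.distributed as dist

def gather_full_state(shard: torch.Tensor, dst: int = 0):
    # Every rank must reach this collective, or the others block forever;
    # this is why the diff keeps the gather running on all ranks and only
    # gates the pinned buffers and file writes on is_master().
    world_size = dist.get_world_size()
    out = [torch.empty_like(shard) for _ in range(world_size)] if dist.get_rank() == dst else None
    dist.gather(shard, gather_list=out, dst=dst)
    # only_on_master-style behaviour: non-master ranks return nothing.
    return torch.cat(out) if out is not None else None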
@@ -227,7 +227,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             from colossalai.utils.safetensors import save_nested

             f_writer = AsyncFileWriter(
-                fp=open(checkpoint_file_path, "wb", buffering=0),
+                checkpoint_file_path,
                 n_entries=self.N_WRITE_ENTRIES,
                 backend="pthread",
             )