Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[checkpointio] optimize zero optim checkpoint io (#4591)
* [zero] update checkpoint io to save memory
* [checkpointio] add device map to save memory
@@ -17,8 +17,13 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
+    load_param_groups_into_optimizer,
+    load_shard_state_dict,
+    load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
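The five added imports are the building blocks of a shard-by-shard loading path: the index file lists the shard files, each shard is read and fed into the optimizer in turn, and the merged state dict is never held in memory all at once. A minimal sketch of that streaming pattern in plain PyTorch (the stream_shards/consume names and the shard layout are illustrative, not ColossalAI's API):

from pathlib import Path
from typing import Callable, Dict, Iterable

import torch


def stream_shards(shard_paths: Iterable[Path], consume: Callable[[Dict], None]) -> None:
    """Load checkpoint shards one at a time so peak memory stays at a single shard."""
    for shard_path in shard_paths:
        # Read one shard onto CPU only.
        shard = torch.load(shard_path, map_location='cpu')
        # Hand it off, e.g. to copy this rank's slice into the optimizer state.
        consume(shard)
        # Drop the full shard before the next file is read.
        del shard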
@@ -126,19 +131,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
-        current_rank_state_dict = optimizer.optim.state_dict()['state']
-        for param_idx, state in current_rank_state_dict.items():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != 'step':
-                    padding_size = (self.coordinator.world_size -
-                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
-                    with torch.no_grad():
-                        v = v.flatten()
-                        if padding_size > 0:
-                            v = torch.nn.functional.pad(v, [0, padding_size])
-                        v_list = v.split(v.numel() // self.coordinator.world_size)
-                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            # shard state dict
+            for param_idx, state in state_dict.items():
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != 'step':
+                        padding_size = (self.coordinator.world_size -
+                                        v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                        with torch.no_grad():
+                            v = v.flatten()
+                            if padding_size > 0:
+                                v = torch.nn.functional.pad(v, [0, padding_size])
+                            v_list = v.split(v.numel() // self.coordinator.world_size)
+                            state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
 
 class LowLevelZeroModel(ModelWrapper):
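The inner loop in the new code is the heart of the change: every non-`step` state tensor is flattened, zero-padded so its length divides evenly by the world size, split into equal chunks, and only the local rank's chunk is kept. A standalone sketch of that arithmetic, assuming a given rank and world_size (the shard_for_rank helper is illustrative, not part of ColossalAI):

import torch


def shard_for_rank(v: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return only this rank's slice of a flattened, zero-padded state tensor."""
    padding_size = (world_size - v.numel() % world_size) % world_size
    with torch.no_grad():
        v = v.flatten()
        if padding_size > 0:
            # Zero-pad so the tensor splits evenly across ranks.
            v = torch.nn.functional.pad(v, [0, padding_size])
        chunks = v.split(v.numel() // world_size)
        return chunks[rank].detach().clone()


# Example: a 10-element tensor sharded across 4 ranks is padded to 12 elements,
# so every rank ends up holding exactly 3 of them.
full = torch.arange(10, dtype=torch.float32)
assert shard_for_rank(full, rank=0, world_size=4).numel() == 3
assert torch.equal(shard_for_rank(full, rank=3, world_size=4), torch.tensor([9., 0., 0.]))

Note that the modulo trick (world_size - numel % world_size) % world_size evaluates to zero when the tensor already divides evenly, so no padding is added in that case.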