Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[checkpointio] General Checkpointing of Sharded Optimizers (#3984)
@@ -11,15 +11,21 @@ from torch.optim import Optimizer
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    get_base_filenames,
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
     get_shard_filename,
     has_index_file,
     is_safetensors_available,
+    load_param_groups_into_optimizer,
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_states_into_optimizer,
+    save_param_groups,
     save_state_dict,
-    shard_checkpoint,
+    shard_model_checkpoint,
+    shard_optimizer_checkpoint,
+    sharded_optimizer_loading_epilogue,
 )
 
 __all__ = ['GeneralCheckpointIO']
@@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
         # save the checkpoint
         save_state_dict(state_dict, checkpoint, use_safetensors)
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
-
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = load_state_dict(checkpoint)
-        optimizer.load_state_dict(checkpoint)
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+        """
+        Load sharded optimizer with the given path to index file.
+        """
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+            del state_dict
+            gc.collect()
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
     def save_sharded_optimizer(
         self,
@@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO):
         prefix: str,
         size_per_shard: int,
     ):
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        """
+        Save sharded optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Offload optimizer states. States are broken into shards within max_shard_size.
+        state_dict = optimizer.state_dict()
+        sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up index file.
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer is going to be split to checkpoint shards. "
+                     f"You can find where each parameters has been saved in the "
+                     f"index located at {save_index_file}.")
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
 
     def save_unsharded_optimizer(
         self,
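As a rough illustration of the docstring above (an assumption inferred from the append_meta_data and append_weight_map calls, not content of this commit), the written pytorch_optim.bin.index.json maps each optimizer state id to the shard file that stores it, roughly:

# Assumed layout of the index file produced by write_index_file(); values are illustrative.
index_contents = {
    "metadata": {
        "param_groups": "pytorch_optim_group.bin",  # recorded via append_meta_data("param_groups", ...)
        "total_size": 123456789,                    # recorded via append_meta_data("total_size", ...)
    },
    "weight_map": {
        "0": "pytorch_optim-00001.bin",             # param_id -> shard file, via append_weight_map(...)
        "1": "pytorch_optim-00001.bin",
        "2": "pytorch_optim-00002.bin",
    },
}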
@@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO):
                            model: nn.Module,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
@@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # shard checkpoint
         state_dict = model.state_dict()
-        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
 
-        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
@@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
-        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []
 
         for shard_file in checkpoint_files:
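For context, a minimal end-to-end usage sketch of the sharded optimizer checkpointing introduced by this commit (not part of the commit itself). The save_sharded_optimizer / load_sharded_optimizer method names and the prefix / size_per_shard parameters appear in the diff; the import path, the positional checkpoint and gather_dtensor arguments, and the ./optim_ckpt file names are assumptions.

# Hypothetical usage sketch, assuming GeneralCheckpointIO is exported from colossalai.checkpoint_io
# and save_sharded_optimizer takes (optimizer, checkpoint_dir, gather_dtensor, prefix, size_per_shard).
import torch
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(32, 32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer actually holds per-parameter states (exp_avg, exp_avg_sq for Adam).
model(torch.randn(8, 32)).sum().backward()
optimizer.step()

ckpt_io = GeneralCheckpointIO()

# Save: per the docstring, this writes pytorch_optim.bin.index.json, pytorch_optim_group.bin
# and pytorch_optim-000XX.bin shards under the given directory.
ckpt_io.save_sharded_optimizer(optimizer, "./optim_ckpt", False, None, 1024)

# Load: the new load_sharded_optimizer expects the path to the index file produced above.
ckpt_io.load_sharded_optimizer(optimizer, "./optim_ckpt/pytorch_optim.bin.index.json", None, 1024)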