[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

This commit is contained in:
Baizhou Zhang
2023-06-15 15:21:26 +08:00
committed by GitHub
parent 8bcad73677
commit c9cff7e7fa
8 changed files with 399 additions and 38 deletions

View File

@@ -11,15 +11,21 @@ from torch.optim import Optimizer
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
get_base_filenames,
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
shard_checkpoint,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
)
__all__ = ['GeneralCheckpointIO']
@@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
"""
Load sharded optimizer with the given path to index file.
"""
optimizer.load_state_dict
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory.')
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
del state_dict
gc.collect()
sharded_optimizer_loading_epilogue(optimizer)
def save_sharded_optimizer(
self,
@@ -59,7 +83,54 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: str,
size_per_shard: int,
):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Offload optimizer states. States are broken into shards within max_shard_size.
state_dict = optimizer.state_dict()
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
self,
@@ -74,7 +145,7 @@ class GeneralCheckpointIO(CheckpointIO):
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
@@ -89,9 +160,9 @@ class GeneralCheckpointIO(CheckpointIO):
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
@@ -128,7 +199,7 @@ class GeneralCheckpointIO(CheckpointIO):
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []
for shard_file in checkpoint_files: