[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

Baizhou Zhang
2023-06-15 15:21:26 +08:00
committed by GitHub
parent 8bcad73677
commit c9cff7e7fa
8 changed files with 399 additions and 38 deletions


@@ -1,17 +1,24 @@
# coding=utf-8
import re
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.tensor.d_tensor.d_tensor import DTensor
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"
# ======================================
# General helper functions
@@ -81,7 +88,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@@ -110,6 +117,50 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
yield current_block, current_block_size
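
A minimal usage sketch of the shard generator above (not part of this commit); the output directory and file-name pattern are assumptions for illustration:

import os

import torch

model = torch.nn.Linear(8, 8)
os.makedirs("model_ckpt", exist_ok=True)
# Each yielded block is a partial state dict small enough to fit the size budget.
for idx, (shard, shard_size) in enumerate(shard_model_checkpoint(model.state_dict())):
    torch.save(shard, os.path.join("model_ckpt", f"pytorch_model-{idx:05d}.bin"))
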
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
# Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
states = state_dict['state']
current_block = {}
current_block_size = 0
for param_id, state in states.items():
ret_block = None
ret_block_size = 0
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# If the states are stored as DTensors, mark isDTensor as true.
if type(state_tensor) == DTensor:
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
if ret_block is not None:
yield ret_block, ret_block_size
yield current_block, current_block_size
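
A similar sketch for the optimizer side (not part of this commit). Note that only state_dict['state'] is sharded here, so 'param_groups' must be persisted separately (see save_param_groups below); the backward/step calls merely populate Adam state, and the paths are assumptions:

import os

import torch
from torch.optim import Adam

model = torch.nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optimizer.step()    # populates 'step', 'exp_avg', 'exp_avg_sq' for every parameter

full_state = optimizer.state_dict()
os.makedirs("optim_ckpt", exist_ok=True)
for idx, (shard, shard_size) in enumerate(shard_optimizer_checkpoint(full_state)):
    torch.save(shard, os.path.join("optim_ckpt", f"pytorch_optim-{idx:05d}.bin"))
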
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
"""
Load shard state dict into model.
@@ -179,6 +230,96 @@ def load_state_dict_into_model(model: nn.Module,
model.__class__.__name__, "\n\t".join(error_msgs)))
def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
"""
Load information of param_groups into an initialized optimizer.
"""
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path)
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
# The params in param_groups are in the form of pytorch tensors.
# For more details, please view source code of Optimizer class in pytorch.
param_groups = optimizer.param_groups
# Check the compatibility of saved_groups and param_groups.
if len(param_groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of original parameter groups")
param_lens = (len(g['params']) for g in param_groups)
saved_lens = (len(g['params']) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group")
# Create a mapping from saved parameter indices to live parameters.
id_map = {
    old_id: p
    for old_id, p in zip(chain.from_iterable(g['params'] for g in saved_groups),
                         chain.from_iterable(g['params'] for g in param_groups))
}
# Update parameter groups, setting their 'params' value.
def update_group(group, new_group):
new_group['params'] = group['params']
return new_group
updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
optimizer.__dict__.update({'param_groups': updated_groups})
return id_map
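
A short sketch of what the returned id_map looks like for a plain two-parameter model (not part of this commit); the group file path is hypothetical and is expected to have been written by save_param_groups, defined below:

import torch
from torch.optim import Adam

model = torch.nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)

id_map = load_param_groups_into_optimizer(optimizer, "optim_ckpt/pytorch_optim_group.bin")
# id_map now maps every saved integer index to the live parameter it belongs to,
# e.g. {0: <weight tensor>, 1: <bias tensor>}, and is consumed by load_states_into_optimizer.
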
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
optimizer(Optimizer): An initialized Optimizer object into which the states will be loaded
state_dict(dict): a mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor).
id_map(dict): a mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated.
"""
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not cast: https://github.com/pytorch/pytorch/issues/74424
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
return value
elif isinstance(value, dict):
return {k: cast(param, v, key=k) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
new_states = defaultdict(dict)
for k, v in state_dict.items():
if k in id_map:
param = id_map[k]
new_states[param] = cast(param, v)
else:
new_states[k] = v
optimizer.state.update(new_states)
def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
# Do the cleanup as in the source code of PyTorch.
optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
optimizer.defaults.setdefault('differentiable', False)
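
Taken together, the three helpers above form the sharded loading path. A hedged end-to-end sketch (not part of this commit); the directory layout and shard file names are assumptions:

import os

import torch
from torch.optim import Adam

model = torch.nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)

ckpt_dir = "optim_ckpt"
id_map = load_param_groups_into_optimizer(optimizer, os.path.join(ckpt_dir, "pytorch_optim_group.bin"))
for fname in sorted(f for f in os.listdir(ckpt_dir) if f.startswith("pytorch_optim-")):
    shard = torch.load(os.path.join(ckpt_dir, fname))
    load_states_into_optimizer(optimizer, shard, id_map)
sharded_optimizer_loading_epilogue(optimizer)
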
# ======================================
# Helper functions for saving state dict
# ======================================
@@ -203,6 +344,18 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to the given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
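
For reference (not part of this commit), the object written here is exactly the list that load_param_groups_into_optimizer reads back; a brief sketch with an assumed path:

import os

import torch
from torch.optim import Adam

optimizer = Adam(torch.nn.Linear(8, 8).parameters(), lr=1e-3)
os.makedirs("optim_ckpt", exist_ok=True)
# Stores a list of group dicts whose 'params' entries are integer indices.
save_param_groups(optimizer.state_dict(), "optim_ckpt/pytorch_optim_group.bin")
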
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@@ -392,28 +545,44 @@ def load_state_dict(checkpoint_file_path: Path):
return torch.load(checkpoint_file_path)
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None and len(variant) > 0:
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
if prefix is not None and len(prefix) > 0:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
splits = splits[:-1] + [prefix] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
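
Expected behavior of add_prefix (not part of this commit), shown on the file names defined at the top of the module:

assert add_prefix("pytorch_optim.bin", "adam") == "pytorch_optim.adam.bin"
assert add_prefix("model.safetensors.index.json", "fp16") == "model.safetensors.index.fp16.json"
assert add_prefix("pytorch_model.bin", None) == "pytorch_model.bin"
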
def get_base_filenames(variant: str = None, use_safetensors: bool = False):
def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base weight filenames
generate base model weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
weights_name = add_prefix(weights_name, prefix)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)
save_index_file = add_prefix(save_index_file, prefix)
return weights_name, save_index_file
def get_optimizer_base_filenames(prefix: str = None):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = add_prefix(states_name, prefix)
save_index_file = STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)
param_group_file = GROUP_FILE_NAME
param_group_file = add_prefix(param_group_file, prefix)
return states_name, save_index_file, param_group_file
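
And the resulting optimizer file names (not part of this commit), with and without a hypothetical "rank0" prefix:

assert get_optimizer_base_filenames() == (
    "pytorch_optim.bin", "pytorch_optim.bin.index.json", "pytorch_optim_group.bin")
assert get_optimizer_base_filenames("rank0") == (
    "pytorch_optim.rank0.bin", "pytorch_optim.bin.index.rank0.json", "pytorch_optim_group.rank0.bin")
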
def get_shard_filename(weights_name: str, idx: int):
"""
get shard file name