[checkpoint] support huggingface style sharded checkpoint (#3461)

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

* [checkpoint] support huggingface style sharded checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
This commit is contained in:
jiangmingyan
2023-04-06 16:23:39 +08:00
committed by GitHub
parent 6afeb1202a
commit 52a933e175
4 changed files with 291 additions and 45 deletions

View File

@@ -1,13 +1,19 @@
# coding=utf-8
from pathlib import Path
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from colossalai.tensor.d_tensor.d_tensor import DTensor
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "model.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "model.bin.index.json"
# ======================================
# General helper functions
# ======================================
def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -68,6 +74,130 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
return False
# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0
for key, weight in state_dict.items():
if type(weight) != DTensor:
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
current_block = {}
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
total_size += weight_size
# Add the last block
sharded_state_dicts.append(current_block)
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file
# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
"""
load shard state dict into model
"""
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import safe_open
from safetensors.torch import load_file as safe_load_file
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file)
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
unexpected_keys: List[str] = []
sub_missing_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(model, state_dict, "")
del load
# deal with missing key
if len(missing_keys) > 0:
deleted_keys = []
for key in missing_keys:
if key not in sub_missing_keys:
deleted_keys.append(key)
for key in deleted_keys:
missing_keys.remove(key)
if strict:
if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# ======================================
# Helper functions for saving state dict
# ======================================
@@ -86,8 +216,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file
save_file(state_dict, checkpoint_file_path)
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)