[booster] gemini plugin support shard checkpoint (#3610)

* gemini plugin add shard checkpoint save/load

* gemini plugin support shard checkpoint

* [API Refactoring] gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
Author: jiangmingyan
Date: 2023-05-05 14:37:21 +08:00
Committed by: GitHub
Parent: 0f785cb1f3
Commit: 307894f74d
9 changed files with 268 additions and 96 deletions

File: colossalai/checkpoint_io/checkpoint_io_base.py

@@ -86,7 +86,7 @@ class CheckpointIO(ABC):
         # the existence of index file means it is a sharded checkpoint
         ckpt_path = Path(checkpoint)
         index_file_exists, index_file_path = has_index_file(checkpoint)

         # return the origin model instead of the unwrapped model
         origin_model = model
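The branch above keys entirely off the presence of an index file: a directory that contains one is treated as a sharded checkpoint, anything else as a single-file checkpoint. A minimal stand-in sketch of that check (the real has_index_file lives in colossalai/checkpoint_io/utils.py; the toy name and the "*.index.json" glob are assumptions for illustration):

from pathlib import Path
from typing import Optional, Tuple

def has_index_file_toy(checkpoint: str) -> Tuple[bool, Optional[Path]]:
    # A directory containing any "*.index.json" file is treated as sharded.
    ckpt_path = Path(checkpoint)
    if ckpt_path.is_dir():
        for candidate in ckpt_path.glob("*.index.json"):
            return True, candidate
    return False, None

index_file_exists, index_file_path = has_index_file_toy("./ckpt")
print("sharded" if index_file_exists else "single-file", index_file_path)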

File: colossalai/checkpoint_io/general_checkpoint_io.py

@@ -1,12 +1,12 @@
 from pathlib import Path
+from functools import reduce
 import torch.nn as nn
 from torch.optim import Optimizer
 import logging
 import os
 import json
 import gc
-from typing import Optional
+from typing import Optional, Iterator, OrderedDict, Tuple

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
@@ -18,10 +18,9 @@ from .utils import (
     shard_checkpoint,
     load_shard_state_dict,
     load_state_dict_into_model,
-    add_variant
+    get_shard_filename,
+    get_base_filenames
 )
-from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

 __all__ = ['GeneralCheckpointIO']
__all__ = ['GeneralCheckpointIO']
@@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO):
         # shard checkpoint
         state_dict = model.state_dict()
-        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-        weights_name = add_variant(weights_name, variant)
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)

         # Save the model
-        for shard_file, shard in shards.items():
+        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        total_size = 0
+        index_file = CheckpointIndexFile(checkpoint_path)
+        for idx, shard_pair in enumerate(state_dict_shard):
+            shard = shard_pair[0]
+            shard_file = get_shard_filename(weights_name, idx)
+            total_size = total_size + shard_pair[1]
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)
+
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)

-        # save index file
-        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-        save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
-        with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-            f.write(content)
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)

         logging.info(
-            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-            f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+            f"The model is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
             f"index located at {save_index_file}."
         )

-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
+                           use_safetensors: bool = False, load_sub_module: bool = True):
         """
         Load a sharded model: the model state is read back from multiple shard files.
         """
@@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO):
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
-        missing_keys = ckpt_index_file.get_all_param_names()
+        missing_keys = []

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
-            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
             del state_dict
             gc.collect()

-        if strict and len(missing_keys) > 0:
-            error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in missing_keys))
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+        if strict:
+            remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+            if len(remain_keys) > 0:
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
+                    ', '.join('"{}"'.format(k) for k in remain_keys))
+                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
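Taken together, the new save path shards the state dict, writes one file per shard plus an index, and the load path streams the shards back. A hedged usage sketch (the method and parameter names are the ones visible in this diff; the index file name assumes the usual "pytorch_model.bin.index.json" value behind WEIGHTS_INDEX_NAME, which the diff does not show):

import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
ckpt_io = GeneralCheckpointIO()

# Writes the shard files plus an index file under ./ckpt.
ckpt_io.save_sharded_model(model, checkpoint_path="./ckpt", variant=None,
                           max_shard_size=1024, use_safetensors=False)

# Streams the shards back one by one; load_sub_module=True (the new flag)
# recurses into child modules as before.
ckpt_io.load_sharded_model(model, "./ckpt/pytorch_model.bin.index.json",
                           strict=True, load_sub_module=True)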

File: colossalai/checkpoint_io/index_file.py

@@ -1,6 +1,8 @@
 import json
 from pathlib import Path
 from typing import Any, List, Union
+import os
+import json

 from .utils import is_dtensor_checkpoint
@@ -18,8 +20,8 @@ class CheckpointIndexFile:
>>> index.export('new_index.json')
"""
-    def __init__(self) -> None:
-        self.root_path = None
+    def __init__(self, root_path=None) -> None:
+        self.root_path = root_path
         self.metadata: dict = dict()
         self.weight_map: dict = dict()
@@ -154,3 +156,13 @@ class CheckpointIndexFile:
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
+
+    def write_index_file(self, save_index_file):
+        """
+        Write the index file.
+        """
+        save_index_file = os.path.join(self.root_path, save_index_file)
+        index = {"metadata": self.metadata, "weight_map": self.weight_map}
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
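For reference, write_index_file serializes exactly the two dicts assembled above. A sketch of the resulting JSON (the values are invented; the shard names follow the get_shard_filename pattern added later in this commit):

import json

index = {
    "metadata": {"total_size": 2101248},
    "weight_map": {
        "0.weight": "pytorch_model-00001.bin",
        "0.bias": "pytorch_model-00001.bin",
        "1.weight": "pytorch_model-00002.bin",
        "1.bias": "pytorch_model-00002.bin",
    },
}
print(json.dumps(index, indent=2, sort_keys=True))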

File: colossalai/checkpoint_io/utils.py

@@ -2,7 +2,7 @@
 from pathlib import Path
 import torch
 import torch.nn as nn
-from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
+from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
 from colossalai.tensor.d_tensor.d_tensor import DTensor

 import re
@@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
+def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
     """
-    sharded_state_dicts = []
     current_block = {}
     current_block_size = 0
-    total_size = 0

     for key, weight in state_dict.items():
+        ret_block = None
+        ret_block_size = 0
         if type(weight) != DTensor:
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip over the maximal size, we split.
             if current_block_size + weight_size > max_shard_size:
-                sharded_state_dicts.append(current_block)
+                ret_block = current_block
+                ret_block_size = current_block_size
                 current_block = {}
                 current_block_size = 0

             current_block[key] = weight
             current_block_size += weight_size
-            total_size += weight_size

+        if ret_block != None:
+            yield ret_block, ret_block_size

     # Add the last block
-    sharded_state_dicts.append(current_block)
+    yield current_block, current_block_size

-    # If we only have one shard, we return it
-    if len(sharded_state_dicts) == 1:
-        return {weights_name: sharded_state_dicts[0]}, None
-
-    # Otherwise, let's build the index
-    weight_map = {}
-    shards = {}
-    for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
-        shard_file = shard_file.replace(
-            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
-        )
-        shards[shard_file] = shard
-        for key in shard.keys():
-            weight_map[key] = shard_file
-
-    # Add the metadata
-    metadata = {"total_size": total_size}
-    index = {"metadata": metadata, "weight_map": weight_map}
-    return shards, index
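The rewrite turns shard_checkpoint from a build-everything-then-return function into a lazy generator: a (block, block_size) pair is yielded as soon as a block fills up, so only one shard has to be materialized at a time, and the filename/index bookkeeping moves to the caller. A simplified, self-contained re-implementation of the pattern (it skips the DTensor filtering and uses numel() * element_size() in place of calculate_tensor_size):

from collections import OrderedDict
from typing import Dict, Iterator, Tuple

import torch

def toy_shard_checkpoint(state_dict: Dict[str, torch.Tensor],
                         max_shard_size: int) -> Iterator[Tuple[OrderedDict, int]]:
    current_block: OrderedDict = OrderedDict()
    current_block_size = 0
    for key, weight in state_dict.items():
        weight_size = weight.numel() * weight.element_size()
        # Yield the current block as soon as the next tensor would overflow it.
        if current_block and current_block_size + weight_size > max_shard_size:
            yield current_block, current_block_size
            current_block, current_block_size = OrderedDict(), 0
        current_block[key] = weight
        current_block_size += weight_size
    # The final, possibly partial block.
    yield current_block, current_block_size

sd = {f"layer{i}.weight": torch.zeros(256, 256) for i in range(4)}
for idx, (block, size) in enumerate(toy_shard_checkpoint(sd, max_shard_size=512 * 1024)):
    print(f"shard {idx}: {list(block)} ({size} bytes)")
# shard 0 holds layer0/layer1, shard 1 holds layer2/layer3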
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
"""
@@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
     else:
         return torch.load(checkpoint_file)

-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
+def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
@@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
     if metadata is not None:
         state_dict._metadata = metadata

-    def load(module: nn.Module, state_dict, prefix=""):
+    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
         if len([key for key in state_dict if key.startswith(prefix)]) > 0:
             module._load_from_state_dict(*args)

-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, state_dict, prefix + name + ".")
+        if load_sub_module:
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")

-    load(model, state_dict, "")
+    load(model, state_dict, "", load_sub_module)
     del load

-    # deal with missing key
-    if len(missing_keys) > 0:
-        deleted_keys = []
-        for key in missing_keys:
-            if key not in sub_missing_keys:
-                deleted_keys.append(key)
-        for key in deleted_keys:
-            missing_keys.remove(key)
+    missing_keys.append(sub_missing_keys)

     if strict:
         if len(unexpected_keys) > 0:
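This is the other half of the per-shard missing-key protocol: each call appends one sub-list, and load_sharded_model (earlier in this commit) intersects them, because a parameter is only genuinely missing if no shard supplied it. A toy illustration of that intersection:

from functools import reduce

missing_keys = [
    ["fc1.weight", "fc2.weight"],   # keys not found while loading shard 0
    ["fc2.weight", "fc3.weight"],   # keys not found while loading shard 1
]
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
print(remain_keys)   # {'fc2.weight'} -- never provided by any shard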
@@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     weights_name = ".".join(splits)

     return weights_name

+
+def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+    """
+    Generate the base weight and index filenames.
+    """
+    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+    weights_name = add_variant(weights_name, variant)
+
+    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+    save_index_file = add_variant(save_index_file, variant)
+
+    return weights_name, save_index_file
+
+
+def get_shard_filename(weights_name: str, idx: int):
+    """
+    Get the shard file name for shard index ``idx``.
+    """
+    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
+    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    return shard_file
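These helpers encode the shard naming scheme; assuming the usual Hugging-Face-style base names behind WEIGHTS_NAME and SAFE_WEIGHTS_NAME (their values are not shown in this diff), the produced names look like this:

def get_shard_filename(weights_name: str, idx: int) -> str:
    # Same logic as the helper added above, repeated here to be runnable.
    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
    return shard_file

print(get_shard_filename("pytorch_model.bin", 0))    # pytorch_model-00001.bin
print(get_shard_filename("model.safetensors", 11))   # model-00012.safetensors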