Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2026-05-05 12:24:38 +00:00.
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load
* [API Refactoring] gemini plugin support shard checkpoint

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
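For orientation before the diff: the feature is driven through the Booster checkpoint API. A hedged sketch of the intended call shape follows; the Gemini/Booster setup is illustrative, the exact `save_model` signature should be checked against the release you use, and the script must run under torchrun with distributed initialization.

```python
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})   # requires a torchrun launch

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)
model, *_ = booster.boost(nn.Linear(8, 8))

# shard=True selects the sharded save path added by this PR: the model is
# written as several shard files plus an index JSON inside `checkpoint_dir`.
booster.save_model(model, "checkpoint_dir", shard=True)
```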
In checkpoint_io_base.py (file inferred from the `from .checkpoint_io_base import CheckpointIO` import later in this diff):

@@ -86,7 +86,7 @@ class CheckpointIO(ABC):

         # the existence of index file means it is a sharded checkpoint
-        ckpt_path = Path(checkpoint)
+        index_file_exists, index_file_path = has_index_file(checkpoint)

         # return the origin model instead of the unwrapped model
         origin_model = model
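The new branch keys off `has_index_file`, which lives in `colossalai.checkpoint_io.utils` and is not shown in this diff. A minimal standalone sketch of the idea (the `*.index.json` glob is an assumption, not the library's exact rule):

```python
from pathlib import Path
from typing import Optional, Tuple

def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """Detect a sharded checkpoint by the presence of an index file."""
    ckpt_path = Path(checkpoint_path)
    # A single-file checkpoint is just the weights file itself.
    if ckpt_path.is_file():
        return False, None
    # A sharded checkpoint is a directory holding shards plus one index JSON.
    index_files = list(ckpt_path.glob("*.index.json"))
    if index_files:
        return True, index_files[0]
    return False, None
```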
In general_checkpoint_io.py:

@@ -1,12 +1,12 @@
 from pathlib import Path
+from functools import reduce

 import torch.nn as nn
 from torch.optim import Optimizer
 import logging
 import os
-import json
-from typing import Optional
+import gc
+from typing import Optional, Iterator, OrderedDict, Tuple

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
@@ -18,10 +18,9 @@ from .utils import (
     shard_checkpoint,
     load_shard_state_dict,
     load_state_dict_into_model,
-    add_variant
+    get_shard_filename,
+    get_base_filenames
 )
-from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

 __all__ = ['GeneralCheckpointIO']
@@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO):

         # shard checkpoint
         state_dict = model.state_dict()
-        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-        weights_name = add_variant(weights_name, variant)
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)

         # Save the model
-        for shard_file, shard in shards.items():
+        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        total_size = 0
+        index_file = CheckpointIndexFile(checkpoint_path)
+        for idx, shard_pair in enumerate(state_dict_shard):
+            shard = shard_pair[0]
+            shard_file = get_shard_filename(weights_name, idx)
+            total_size = total_size + shard_pair[1]
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)
+
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)

         # save index file
-        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-        save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
-        with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-            f.write(content)
-
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
         logging.info(
-            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+            f"The model is going to be split to checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
             f"index located at {save_index_file}."
         )

-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
+                           use_safetensors: bool = False, load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
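Read as a whole, the new save path streams shards out of the `shard_checkpoint` generator and grows the index as each shard hits disk, so only one shard ever lives in memory. A condensed standalone sketch of the same flow (plain `torch.save`, illustrative file names, no `CheckpointIndexFile`):

```python
import json
import os
from typing import Iterator, Tuple

import torch

def save_sharded(shards: Iterator[Tuple[dict, int]], checkpoint_path: str,
                 weights_name: str = "pytorch_model.bin") -> None:
    os.makedirs(checkpoint_path, exist_ok=True)
    weight_map, total_size = {}, 0
    for idx, (shard, shard_size) in enumerate(shards):
        shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}.bin")
        total_size += shard_size
        for key in shard:
            weight_map[key] = shard_file          # index: param name -> shard file
        torch.save(shard, os.path.join(checkpoint_path, shard_file))
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_path, f"{weights_name}.index.json"), "w") as f:
        json.dump(index, f, indent=2, sort_keys=True)
```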
@@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO):
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
-        missing_keys = ckpt_index_file.get_all_param_names()
+        missing_keys = []

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
-            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
             del state_dict
             gc.collect()

-        if strict and len(missing_keys) > 0:
-            error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in missing_keys))
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+        if strict:
+            remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+            if len(remain_keys) > 0:
+                error_msgs = ['Missing key(s) in state_dict: {}. '.format(
+                    ', '.join('"{}"'.format(k) for k in remain_keys))]
+                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
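Note the change in what `missing_keys` holds: it is now a list of per-shard sub-lists, and a key counts as missing only if every shard failed to supply it, hence the set intersection under `reduce`. A standalone illustration:

```python
from functools import reduce

# One sub-list of missing keys per loaded shard.
missing_keys = [
    ["encoder.weight", "decoder.weight"],   # keys shard 1 did not provide
    ["encoder.weight", "head.bias"],        # keys shard 2 did not provide
]

# A key is genuinely missing only if no shard provided it.
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
print(remain_keys)   # {'encoder.weight'}
```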
In index_file.py:

@@ -1,6 +1,8 @@
 import json
 from pathlib import Path
 from typing import Any, List, Union
+import os
+import json

 from .utils import is_dtensor_checkpoint
@@ -18,8 +20,8 @@ class CheckpointIndexFile:
     >>> index.export('new_index.json')
     """

-    def __init__(self) -> None:
-        self.root_path = None
+    def __init__(self, root_path=None) -> None:
+        self.root_path = root_path
         self.metadata: dict = dict()
         self.weight_map: dict = dict()
@@ -154,3 +156,13 @@ class CheckpointIndexFile:
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
+
+    def write_index_file(self, save_index_file):
+        """
+        Write the index file.
+        """
+        save_index_file = os.path.join(self.root_path, save_index_file)
+        index = {"metadata": self.metadata, "weight_map": self.weight_map}
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
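Taken together with the new constructor, the intended call pattern on the writer side looks roughly like this (the method names come from the diff; the module path and file names are assumptions consistent with the imports shown above):

```python
from colossalai.checkpoint_io.index_file import CheckpointIndexFile

index_file = CheckpointIndexFile("checkpoint_dir")
index_file.append_weight_map("encoder.weight", "pytorch_model-00001.bin")
index_file.append_weight_map("decoder.weight", "pytorch_model-00002.bin")
index_file.append_meta_data("total_size", 4096)
# Writes checkpoint_dir/pytorch_model.bin.index.json with metadata + weight_map.
index_file.write_index_file("pytorch_model.bin.index.json")
```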
In utils.py:

@@ -2,7 +2,7 @@
 from pathlib import Path
 import torch
 import torch.nn as nn
-from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
+from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 import re
@@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
+def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
     """
-    sharded_state_dicts = []
     current_block = {}
     current_block_size = 0
-    total_size = 0

     for key, weight in state_dict.items():
+        ret_block = None
+        ret_block_size = 0
         if type(weight) != DTensor:
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
             if current_block_size + weight_size > max_shard_size:
-                sharded_state_dicts.append(current_block)
+                ret_block = current_block
+                ret_block_size = current_block_size
                 current_block = {}
                 current_block_size = 0

             current_block[key] = weight
             current_block_size += weight_size
-            total_size += weight_size

+        if ret_block != None:
+            yield ret_block, ret_block_size

     # Add the last block
-    sharded_state_dicts.append(current_block)
-
-    # If we only have one shard, we return it
-    if len(sharded_state_dicts) == 1:
-        return {weights_name: sharded_state_dicts[0]}, None
-
-    # Otherwise, let's build the index
-    weight_map = {}
-    shards = {}
-    for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
-        shard_file = shard_file.replace(
-            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
-        )
-        shards[shard_file] = shard
-        for key in shard.keys():
-            weight_map[key] = shard_file
-
-    # Add the metadata
-    metadata = {"total_size": total_size}
-    index = {"metadata": metadata, "weight_map": weight_map}
-    return shards, index
+    yield current_block, current_block_size

 def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
     """
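To see the memory benefit of the rewrite, here is a self-contained version of the same generator pattern (simplified: no DTensor filtering, and a size helper that mirrors the bytes-per-element idea; both simplifications are assumptions for the demo):

```python
from typing import Dict, Iterator, Tuple

import torch

def tensor_size_mb(tensor: torch.Tensor) -> float:
    # size of a dense tensor in MB
    return tensor.numel() * tensor.element_size() / 1024 / 1024

def iter_shards(state_dict: Dict[str, torch.Tensor],
                max_shard_size: float = 1024) -> Iterator[Tuple[dict, float]]:
    current_block, current_block_size = {}, 0.0
    for key, weight in state_dict.items():
        weight_size = tensor_size_mb(weight)
        # Yield the current block before it would exceed the cap.
        if current_block and current_block_size + weight_size > max_shard_size:
            yield current_block, current_block_size
            current_block, current_block_size = {}, 0.0
        current_block[key] = weight
        current_block_size += weight_size
    yield current_block, current_block_size   # last (possibly only) block

# Shards are produced lazily, so the caller holds one block at a time.
sd = {f"w{i}": torch.zeros(1024, 1024) for i in range(8)}   # 4 MB each in fp32
for idx, (shard, size) in enumerate(iter_shards(sd, max_shard_size=10)):
    print(idx, sorted(shard), f"{size:.0f} MB")
```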
@@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
     else:
         return torch.load(checkpoint_file)

-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
+def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
     this module and its descendants.
@@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
     if metadata is not None:
         state_dict._metadata = metadata

-    def load(module: nn.Module, state_dict, prefix=""):
+    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
         if len([key for key in state_dict if key.startswith(prefix)]) > 0:
             module._load_from_state_dict(*args)
+        if load_sub_module:
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")

-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, state_dict, prefix + name + ".")
-
-    load(model, state_dict, "")
+    load(model, state_dict, "", load_sub_module)
     del load

+    # deal with missing keys
+    if len(missing_keys) > 0:
+        deleted_keys = []
+        for key in missing_keys:
+            if key not in sub_missing_keys:
+                deleted_keys.append(key)
+        for key in deleted_keys:
+            missing_keys.remove(key)
+    missing_keys.append(sub_missing_keys)

     if strict:
         if len(unexpected_keys) > 0:
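The `load_sub_module` switch is what the Gemini path relies on: with it off, only the top-level `_load_from_state_dict` call runs, and the caller walks children itself, feeding each one only its own shard. A toy rendering of that control flow (standalone; `walk` is a hypothetical stand-in for the `load` closure above):

```python
import torch.nn as nn

def walk(module: nn.Module, prefix: str = "", load_sub_module: bool = True) -> None:
    print(f"visiting {prefix or '<root>'}")      # where _load_from_state_dict would run
    if load_sub_module:
        for name, child in module._modules.items():
            if child is not None:
                walk(child, prefix + name + ".", load_sub_module)

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
walk(model)                         # visits <root>, 0., 1.
walk(model, load_sub_module=False)  # visits <root> only
```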
@@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
         weights_name = ".".join(splits)

     return weights_name
+
+
+def get_base_filenames(variant: str = None, use_safetensors: bool = False):
+    """
+    Generate base weight filenames.
+    """
+    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+    weights_name = add_variant(weights_name, variant)
+
+    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+    save_index_file = add_variant(save_index_file, variant)
+
+    return weights_name, save_index_file
+
+
+def get_shard_filename(weights_name: str, idx: int):
+    """
+    Get the shard file name.
+    """
+    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
+    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    return shard_file
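Assuming the Hugging Face-style constants (`WEIGHTS_NAME == "pytorch_model.bin"`, `WEIGHTS_INDEX_NAME == "pytorch_model.bin.index.json"`; the diff imports them but their values are not shown), the naming helpers compose like this:

```python
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename

weights_name, save_index_file = get_base_filenames(variant=None, use_safetensors=False)
print(weights_name)                          # pytorch_model.bin
print(save_index_file)                       # pytorch_model.bin.index.json
print(get_shard_filename(weights_name, 0))   # pytorch_model-00001.bin (idx is zero-based)
```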