mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint --------- Co-authored-by: luchen <luchen@luchendeMBP.lan> Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
|
||||
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
import re
|
||||
|
||||
@@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
||||
# ======================================
|
||||
# Helper functions for saving shard file
|
||||
# ======================================
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
"""
|
||||
sharded_state_dicts = []
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
total_size = 0
|
||||
|
||||
for key, weight in state_dict.items():
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
if type(weight) != DTensor:
|
||||
weight_size = calculate_tensor_size(weight)
|
||||
|
||||
# If this weight is going to tip up over the maximal size, we split.
|
||||
if current_block_size + weight_size > max_shard_size:
|
||||
sharded_state_dicts.append(current_block)
|
||||
ret_block = current_block
|
||||
ret_block_size = current_block_size
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
current_block[key] = weight
|
||||
current_block_size += weight_size
|
||||
total_size += weight_size
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
# Add the last block
|
||||
sharded_state_dicts.append(current_block)
|
||||
yield current_block, current_block_size
|
||||
|
||||
# If we only have one shard, we return it
|
||||
if len(sharded_state_dicts) == 1:
|
||||
return {weights_name: sharded_state_dicts[0]}, None
|
||||
|
||||
# Otherwise, let's build the index
|
||||
weight_map = {}
|
||||
shards = {}
|
||||
|
||||
for idx, shard in enumerate(sharded_state_dicts):
|
||||
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
|
||||
shard_file = shard_file.replace(
|
||||
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
|
||||
)
|
||||
shards[shard_file] = shard
|
||||
for key in shard.keys():
|
||||
weight_map[key] = shard_file
|
||||
|
||||
# Add the metadata
|
||||
metadata = {"total_size": total_size}
|
||||
index = {"metadata": metadata, "weight_map": weight_map}
|
||||
return shards, index
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
|
||||
"""
|
||||
@@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
|
||||
else:
|
||||
return torch.load(checkpoint_file)
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
|
||||
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants.
|
||||
|
||||
@@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module: nn.Module, state_dict, prefix=""):
|
||||
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
module._load_from_state_dict(*args)
|
||||
if load_sub_module:
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".")
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".")
|
||||
|
||||
load(model, state_dict, "")
|
||||
load(model, state_dict, "", load_sub_module)
|
||||
del load
|
||||
|
||||
# deal with missing key
|
||||
if len(missing_keys) > 0:
|
||||
deleted_keys = []
|
||||
for key in missing_keys:
|
||||
if key not in sub_missing_keys:
|
||||
deleted_keys.append(key)
|
||||
for key in deleted_keys:
|
||||
missing_keys.remove(key)
|
||||
missing_keys = missing_keys.append(sub_missing_keys)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
@@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
|
||||
|
||||
def get_base_filenames(variant: str=None, use_safetensors: bool=False):
|
||||
"""
|
||||
generate base weight filenames
|
||||
"""
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = add_variant(save_index_file, variant)
|
||||
|
||||
return weights_name, save_index_file
|
||||
|
||||
def get_shard_filename(weights_name: str, idx: int):
|
||||
"""
|
||||
get shard file name
|
||||
"""
|
||||
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
|
||||
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
|
||||
return shard_file
|
Reference in New Issue
Block a user