[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load
* [API Refactoring] gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
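For orientation, a minimal usage sketch of the two methods this diff touches. The method names and the `max_shard_size`, `use_safetensors`, and index-file parameters are read off the diff below; the import path, the `pytorch_model.bin.index.json` filename (Hugging Face-style sharding convention), and the assumption that `ckpt_dir` already exists are not in this commit and should be treated as guesses:

```python
# Illustrative sketch only: exercises save_sharded_model / load_sharded_model
# as they appear in this diff. The toy model and paths are invented; the
# index filename assumes the Hugging Face-style naming convention.
from pathlib import Path

import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed re-export

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 16))
ckpt_io = GeneralCheckpointIO()

# Save: the state dict is cut into shards no larger than max_shard_size
# (units assumed to be MB), one file per shard, plus a JSON index that maps
# each parameter name to the shard file holding it.
ckpt_io.save_sharded_model(model, "ckpt_dir", max_shard_size=1024, use_safetensors=False)

# Load: the index file is read first, then each shard is loaded in turn.
ckpt_io.load_sharded_model(model, Path("ckpt_dir/pytorch_model.bin.index.json"))
```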
```diff
@@ -1,12 +1,12 @@
 from pathlib import Path
 from functools import reduce

 import torch.nn as nn
 from torch.optim import Optimizer
 import logging
 import os
 import json
 import gc
-from typing import Optional
+from typing import Optional, Iterator, OrderedDict, Tuple

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
```
```diff
@@ -18,10 +18,9 @@ from .utils import (
     shard_checkpoint,
     load_shard_state_dict,
     load_state_dict_into_model,
-    add_variant
+    get_shard_filename,
+    get_base_filenames
 )
-from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME


 __all__ = ['GeneralCheckpointIO']
```
```diff
@@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO):

         # shard checkpoint
         state_dict = model.state_dict()
-        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-        weights_name = add_variant(weights_name, variant)
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)

         # Save the model
-        for shard_file, shard in shards.items():
+        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        total_size = 0
+        index_file = CheckpointIndexFile(checkpoint_path)
+        for idx, shard_pair in enumerate(state_dict_shard):
+            shard = shard_pair[0]
+            shard_file = get_shard_filename(weights_name, idx)
+            total_size = total_size + shard_pair[1]
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)
+
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)

         # save index file
-        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-
-        save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
-        with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-            f.write(content)
-
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
         logging.info(
-            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+            f"The model is going to be split to checkpoint shards. "
+            f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
         )

-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
+                           use_safetensors: bool = False, load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
```
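The save path above replaces the hand-rolled JSON dump with a `CheckpointIndexFile` that is filled in shard by shard. Below is a self-contained sketch of that control flow, with the library helpers replaced by simplified stand-ins (`shard_checkpoint` here is an illustrative generator, not the ColossalAI function; the shard filename pattern mimics `get_shard_filename` but is an assumption):

```python
# Self-contained sketch of the new save loop: iterate (shard, size) pairs,
# accumulate total_size, and build a weight map from parameter name to shard file.
import torch

def shard_checkpoint(state_dict, max_shard_size):
    """Stand-in: yield (shard_dict, shard_size_bytes) pieces under max_shard_size."""
    shard, size = {}, 0
    for name, tensor in state_dict.items():
        t_size = tensor.numel() * tensor.element_size()
        if shard and size + t_size > max_shard_size:
            yield shard, size
            shard, size = {}, 0
        shard[name] = tensor
        size += t_size
    if shard:
        yield shard, size

state_dict = {f"layer{i}.weight": torch.zeros(64, 64) for i in range(8)}
weight_map, total_size = {}, 0
for idx, (shard, size) in enumerate(shard_checkpoint(state_dict, max_shard_size=50_000)):
    shard_file = f"pytorch_model-{idx:05d}.bin"   # cf. get_shard_filename
    total_size += size                            # cf. append_meta_data("total_size", ...)
    for key in shard:
        weight_map[key] = shard_file              # cf. index_file.append_weight_map
print(total_size, weight_map)
```

The index file then records `total_size` as metadata alongside this weight map, so a loader can locate every parameter without opening all shards.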
```diff
@@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO):
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
-        missing_keys = ckpt_index_file.get_all_param_names()
+        missing_keys = []

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
-            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
             del state_dict
             gc.collect()

-        if strict and len(missing_keys) > 0:
-            error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in missing_keys))
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+        if strict:
+            remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+            if len(remain_keys) > 0:
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
+                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
```
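One subtlety in the new strict check: judging by the `reduce` expression, `missing_keys` is now a list of per-shard lists (one entry per `load_state_dict_into_model` call), so a key only counts as truly missing if *no* shard supplied it — hence the set intersection. A tiny illustration (the per-shard lists are invented; how `load_state_dict_into_model` populates them is inferred from the expression, not shown in this diff):

```python
# Why the strict check intersects per-shard missing-key lists: a parameter is
# only missing if every shard failed to provide it.
from functools import reduce

missing_keys = [
    ["fc1.weight", "fc2.weight", "head.bias"],   # keys not found in shard 0
    ["fc1.weight", "head.bias"],                 # keys not found in shard 1
    ["fc1.weight", "fc2.weight"],                # keys not found in shard 2
]
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
print(remain_keys)  # {'fc1.weight'} -- missing from all shards, so strict mode raises
```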