Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[checkpoint] refactored the API and added safetensors support (#3427)
* [checkpoint] refactored the API and added safetensors support
* polish code
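Usage of the refactored interface, pieced together from the docstrings and signatures in this diff (a minimal sketch; GeneralCheckpointIO is assumed here as a stand-in for any concrete CheckpointIO implementation and is not part of this diff):

    # sketch only; the concrete CheckpointIO class used here is an assumption
    import torch.nn as nn
    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = nn.Linear(8, 8)
    checkpoint_io = GeneralCheckpointIO()

    # unsharded save; distributed tensors are gathered to the first device by default
    checkpoint_io.save_model(model, 'model.pt')

    # unsharded save in safetensors format (the new use_safetensors flag)
    checkpoint_io.save_model(model, 'model.safetensors', use_safetensors=True)

    # sharded save: writes shard files plus a model.index.json into the directory
    checkpoint_io.save_model(model, './checkpoints/', shard=True, size_per_shard=1024)

    # loading auto-detects sharded vs. unsharded checkpoints from the index file
    model = checkpoint_io.load_model(model, './checkpoints/')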
@@ -1,7 +1,6 @@
-import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Union
+from typing import Union

 import torch
 import torch.nn as nn
@@ -10,7 +9,9 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.interface import ModelWrapper

-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
+from .utils import has_index_file
+
+__all__ = ['CheckpointIO']


 class CheckpointIO(ABC):
@@ -25,15 +26,31 @@ class CheckpointIO(ABC):
         >>> # load model from checkpoint
         >>> model = checkpoint_io.load_model(model, 'model.pt')
         >>>
-        >>> # save model to checkpoint
+        >>> # save model to checkpoint, any distributed tensor is gathered by default
         >>> checkpoint_io.save_model(model, 'model.pt')
         >>>
+        >>> # if the model contains distributed tensors and you don't want to gather them,
+        >>> # each rank will save its own shard of the distributed tensor
+        >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
+        >>>
         >>> # save model to sharded checkpoints
         >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
         >>>
+        >>> # save model to sharded checkpoints without gathering distributed tensors
+        >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
+        >>>
+        >>> # Note:
+        >>> # 1. loading from distributed tensors is not supported; conversion from distributed-tensor
+        >>> #    checkpoints to full-tensor checkpoints should be done offline via our CLI
+        >>> # 2. you don't have to specify whether the checkpoint is sharded or not when loading the model,
+        >>> #    as it will be automatically detected
+        >>>
         >>> # load model from sharded checkpoints
         >>> model = checkpoint_io.load_model(model, './checkpoints/')
         >>>
+        >>> # load model from unsharded checkpoints
+        >>> model = checkpoint_io.load_model(model, 'model.pt')
+        >>>
         >>> # load optimizer from checkpoint
         >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
         >>>
@@ -58,21 +75,27 @@
                 1. a file path, e.g. 'model.pt'
                 2. a path to a json file which defines the index to the sharded checkpoint
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+                Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
         # since we only support loaded sharded and unsharded weight format
         # containing no distributed tensors, dtensor -> full tensor conversion
         # should be done offline via our CLI
+        # the existence of index file means it is a sharded checkpoint
         ckpt_path = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+        index_file_exists, index_file_path = has_index_file(checkpoint)

+        # return the origin model instead of the unwrapped model
+        origin_model = model
+
         if isinstance(model, ModelWrapper):
             model = model.unwrap()

-        if is_sharded:
-            self.load_sharded_model(model, ckpt_path, strict)
+        if index_file_exists:
+            self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, ckpt_path, strict)
+            self.load_unsharded_model(model, checkpoint, strict)

+        return origin_model

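The new dispatch above relies on has_index_file(checkpoint) from .utils returning an (exists, path) pair; that helper is not part of this diff, so the following is only an illustrative stand-in for the assumed contract, modelled on the is_sharded_checkpoint logic removed further below:

    # illustrative stand-in; not the actual colossalai.checkpoint_io.utils implementation
    from pathlib import Path
    from typing import Optional, Tuple

    def has_index_file_sketch(checkpoint: str) -> Tuple[bool, Optional[Path]]:
        ckpt_path = Path(checkpoint)
        if ckpt_path.is_file() and ckpt_path.name.endswith('.index.json'):
            # a *.index.json file passed directly marks a sharded checkpoint
            return True, ckpt_path
        if ckpt_path.is_dir():
            # a directory is treated as sharded iff it contains exactly one *.index.json
            index_files = list(ckpt_path.glob('*.index.json'))
            if len(index_files) == 1:
                return True, index_files[0]
        return False, None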
@@ -80,8 +103,10 @@
                    model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
                    shard: bool = False,
+                   gather_dtensor: bool = True,
                    prefix: str = None,
-                   size_per_shard: int = 1024):
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """
         Save model to checkpoint.

@@ -103,17 +128,19 @@
             shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
             prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
+            use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
         """

         if isinstance(model, ModelWrapper):
             model = model.unwrap()

         if shard:
-            self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
-            self.save_unsharded_model(model, checkpoint)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """
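With shard=True, the checkpoint argument must be a directory; the loading path only requires that exactly one *.index.json end up there, so the layout below is illustrative (shard file names are an assumption, following the {prefix}-{index}-of-{total}.bin pattern of the helper removed at the bottom of this diff):

    # illustrative result of checkpoint_io.save_model(model, './checkpoints/', shard=True)
    # ./checkpoints/
    # ├── model.index.json     # maps each parameter name to the shard file holding it
    # ├── model-1-of-2.bin
    # └── model-2-of-2.bin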
@@ -123,22 +150,27 @@
             optimizer (Optimizer): optimizer to be loaded.
             checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in the
         """
-        ckpt_path = Path(checkpoint)
-        is_sharded = self.is_sharded_checkpoint(ckpt_path)
+        index_file_exists, index_file_path = has_index_file(checkpoint)

-        if is_sharded:
-            self.load_sharded_optimizer(optimizer, ckpt_path)
+        if Path(checkpoint).is_dir() and not index_file_exists:
+            # if the checkpoint is a directory and there is no index file, raise error
+            raise ValueError(f'Cannot find index file in {checkpoint}')
+
+        if index_file_exists:
+            # the existence of index file means it is a sharded checkpoint
+            self.load_sharded_optimizer(optimizer, index_file_path)
         else:
-            self.load_unsharded_optimizer(optimizer, ckpt_path)
+            self.load_unsharded_optimizer(optimizer, checkpoint)

     def save_optimizer(self,
                        optimizer: Optimizer,
                        checkpoint: str,
                        shard: bool = False,
+                       gather_dtensor=True,
                        prefix: str = None,
                        size_per_shard: int = 1024):
         """
-        Save optimizer to checkpoint.
+        Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.

         Args:
             optimizer (Optimizer): optimizer to be saved.
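Continuing the usage sketch from the top of this page, the optimizer path mirrors the model path (per the docstring above, optimizer state saving does not go through safetensors):

    # unsharded optimizer checkpoint
    checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
    optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')

    # sharded optimizer checkpoint; the target directory must end up containing an *.index.json,
    # otherwise load_optimizer raises ValueError('Cannot find index file in ...')
    checkpoint_io.save_optimizer(optimizer, './optim_ckpt/', shard=True, size_per_shard=1024)
    optimizer = checkpoint_io.load_optimizer(optimizer, './optim_ckpt/')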
@@ -148,30 +180,33 @@
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
             shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
                 multiple files. The optimizer shards will be specified by an `optimizer.index.json` file.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

     # ========================================================
     # Abstract methods for model loading/saving implementation
     # ========================================================
     @abstractmethod
-    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
         """
         Load model from sharded checkpoint.

         Args:
             model (nn.Module): model to be loaded.
-            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+            index_file_path (str): checkpoint path. It should be a path to the .index.json file or to a directory which contains a .index.json file.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
         """
         pass

     @abstractmethod
-    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         """
         Load model from unsharded checkpoint.

@@ -184,26 +219,31 @@
         pass

     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
+                           size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.

         Args:
             model (nn.Module): model to be saved.
-            checkpoint (Path): checkpoint path. It should be a directory path.
+            checkpoint (str): checkpoint path. It should be a directory path.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
             prefix (str): prefix for the model checkpoint.
             size_per_shard (int): size per shard in MB.
+            use_safetensors (bool): whether to use safe tensors.
         """
         pass

     @abstractmethod
-    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to unsharded checkpoint.

         Args:
             model (nn.Module): model to be saved.
-            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
+            use_safetensors (bool): whether to use safe tensors.
         """
         pass

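These abstract hooks are what a concrete backend fills in. A minimal, hypothetical single-process implementation of just the unsharded pair, using the safetensors package for the new use_safetensors flag (this is not the implementation shipped in this PR):

    # hypothetical sketch; assumes no distributed tensors, so gather_dtensor is ignored
    import torch
    import torch.nn as nn
    from safetensors.torch import load_file, save_file

    class NaiveUnshardedCheckpointIO:

        def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
            state_dict = model.state_dict()
            if use_safetensors:
                # safetensors expects a flat dict of contiguous tensors
                save_file({k: v.contiguous().cpu() for k, v in state_dict.items()}, checkpoint)
            else:
                torch.save(state_dict, checkpoint)

        def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
            if checkpoint.endswith('.safetensors'):
                state_dict = load_file(checkpoint)
            else:
                state_dict = torch.load(checkpoint, map_location='cpu')
            model.load_state_dict(state_dict, strict=strict)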
@@ -212,13 +252,13 @@
     # ========================================================

     @abstractmethod
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
         """
         Load optimizer from sharded checkpoint.

         Args:
             optimizer (Optimizer): optimizer to be loaded.
-            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+            index_file_path (str): checkpoint path. It should be a path to the .index.json file or to a directory which contains a .index.json file.
             prefix (str): prefix for the optimizer checkpoint.
             size_per_shard (int): size per shard in MB.
         """
@@ -236,26 +276,29 @@
         pass

     @abstractmethod
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
         """
         Save optimizer to sharded checkpoint.

         Args:
             optimizer (Optimizer): optimizer to be saved.
             checkpoint (Path): checkpoint path. It should be a directory path.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
             prefix (str): prefix for the optimizer checkpoint.
             size_per_shard (int): size per shard in MB.
         """
         pass

     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
         """
         Save optimizer to unsharded checkpoint.

         Args:
             optimizer (Optimizer): optimizer to be saved.
             checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
         """
         pass

@@ -264,7 +307,6 @@
     # as this is quite standard, there is no need
     # to make them abstract
     # ============================================

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         Save lr scheduler to checkpoint.
@@ -285,231 +327,3 @@ class CheckpointIO(ABC):
         """
         state_dict = torch.load(checkpoint)
         lr_scheduler.load_state_dict(state_dict)
-
-    # ========================================
-    # Helper functions for loading state dict
-    # ========================================
-
-    def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
-        """
-        Get the index file path for a sharded checkpoint.
-
-        Args:
-            checkpoint_path (Path): path to the checkpoint.
-
-        Returns:
-            Path: path to the index file.
-        """
-        if checkpoint_path.is_file():
-            # check if it is .index.json
-            if checkpoint_path.name.endswith('.index.json'):
-                return checkpoint_path
-            else:
-                raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ')
-        elif checkpoint_path.is_dir():
-            # check if there is only one a file ending with .index.json in this directory
-            index_files = list(checkpoint_path.glob('*.index.json'))
-            if len(index_files) == 1:
-                return index_files[0]
-            else:
-                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
-
-    def is_sharded_checkpoint(self, checkpoint_path: Path):
-        """
-        Check whether the checkpoint is sharded.
-
-        Args:
-            checkpoint (str): checkpoint path.
-
-        Returns:
-            bool: whether the checkpoint is sharded.
-        """
-        if checkpoint_path.is_file():
-            # check if it is .index.json
-            if checkpoint_path.name.endswith('.index.json'):
-                return True
-            else:
-                return False
-        elif checkpoint_path.is_dir():
-            # check if there is only one a file ending with .index.json in this directory
-            index_files = list(checkpoint_path.glob('*.index.json'))
-            if len(index_files) == 1:
-                return True
-            else:
-                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
-
-    def get_checkpoint_shard_filenames(self, index_file_path: Path):
-        """
-        Get checkpoint shard filenames from a json file.
-
-        Args:
-            index_file_path (Path): path to the json file.
-
-        Returns:
-            list: checkpoint shard filenames.
-        """
-        with open(str(index_file_path), 'r') as f:
-            shard_filenames = json.load(f)
-
-        if "weight_map" in index:
-            index = index["weight_map"]
-
-        checkpoint_root_path = index_file_path.absolute().parent
-
-        # read the checkpoint file list from the json file and get a list of unique file names
-        checkpoint_files = sorted(list(set(index.values())))
-
-        # get the absolute paths for all checkpoint files
-        checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
-        return shard_filenames
-
-    def load_safetensors_state_dict(self, *args, **kwargs):
-        """
-        Load safetensors state dict from checkpoint.
-        """
-        # TODO(FrankLeeeee): support huggingface safetensors
-        raise NotImplementedError("This method is not implemented to support safe tensors")
-
-    def load_state_dict(self, checkpoint_file_path: Path):
-        """
-        Load state dict from checkpoint.
-
-        Args:
-            checkpoint_file_path (Path): path to the checkpoint file.
-
-        Returns:
-            dict: state dict.
-        """
-        return torch.load(str(checkpoint_file_path))
-
-    # ======================================
-    # Helper functions for saving state dict
-    # ======================================
-
-    def save_safetensors_state_dict(self, *args, **kwargs):
-        """
-        Save safetensors state dict to checkpoint.
-        """
-        # TODO(FrankLeeeee): support huggingface safetensors
-        raise NotImplementedError("This method is not implemented to support safe tensors")
-
-    def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
-        """
-        Generate checkpoint shard file name.
-
-        Args:
-            index (int): index of the shard.
-            total_number (int): total number of shards.
-            prefix (str): prefix of the shard file name. Default: None.
-        """
-        if prefix is None:
-            return f"{index}-of-{total_number}.bin"
-        else:
-            return f"{prefix}-{index}-of-{total_number}.bin"
-
-    def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
-        """
-        Save state dict to checkpoint.
-
-        Args:
-            state_dict (dict): state dict.
-            checkpoint_file_path (Path): path to the checkpoint file.
-        """
-        torch.save(state_dict, str(checkpoint_file_path))
-
-    def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
-                                 checkpoint_path: Path):
-        """
-        Save state dict as shard.
-
-        Args:
-            state_dict (dict): state dict.
-            checkpoint_path (Path): path to the checkpoint file.
-        """
-        # generate the shard name
-        shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
-        shard_file_path = checkpoint_path.joinpath(shard_file_name)
-
-        # save the shard
-        self.save_checkpoint(state_dict, shard_file_path)
-
-    def calculate_param_size(self, param: torch.Tensor):
-        """
-        Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
-        If so, a new shard should be created.
-
-        ArgsL
-            param (torch.Tensor): parameter tensor.
-        """
-        # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
-        return param.numel() * param.element_size() / 1024 / 1024
-
-
-class ShardCheckpointIndexFile:
-    """
-    This class is a data structure to keep the content in the index.json file for sharded checkpoint.
-
-    Example:
-        >>> index = ShardCheckpointIndexFile()
-        >>> index.load('index.json')
-        >>> index.append_metadata('model_type', 'bert')
-        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
-        >>> index.export('index.json')
-    """
-
-    def __init__(self) -> None:
-        self.metadata: dict = dict()
-        self.weight_map: dict = dict()
-
-    def load(self, json_path: str):
-        """
-        Load the index file from a json file.
-
-        Args:
-            json_path (str): path to the json file.
-        """
-        # load the json file
-        with open(json_path, 'r') as f:
-            index = json.load(f)
-
-        # assign attributes if exists
-        if "metadata" in index:
-            self.metadata = index["metadata"]
-        if "weight_map" in index:
-            self.weight_map = index["weight_map"]
-
-    def export(self, json_path: str):
-        """
-        Export the index file to a json file.
-
-        Args:
-            json_path (str): path to the json file.
-        """
-        # create the index file
-        index = dict()
-        index["metadata"] = self.metadata
-        index["weight_map"] = self.weight_map
-
-        # export the index file
-        with open(json_path, 'w') as f:
-            json.dump(index, f, indent=4)
-
-    def append_weight_map(self, param_name: str, shard_file: str):
-        """
-        Append a weight map entry to the index file.
-
-        Args:
-            param_name (str): name of the parameter.
-            shard_file (str): name of the shard file.
-        """
-        self.weight_map[param_name] = shard_file
-
-    def append_meta_data(self, name: str, val: Any):
-        """
-        Append a metadata entry to the index file.
-
-        Args:
-            name (str): name of the metadata.
-            val (Any): value of the metadata.
-        """
-        self.metadata[name] = val
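For reference, the index file that ShardCheckpointIndexFile.export wrote (and that the index-file detection in the new loading path still looks for) is plain JSON with a metadata block and a weight_map from parameter names to shard files, along these lines (field values are illustrative):

    import json

    # shape of the *.index.json produced by ShardCheckpointIndexFile.export
    index = {
        "metadata": {
            "model_type": "bert",
        },
        "weight_map": {
            "bert.embeddings.word_embeddings.weight": "bert.embeddings.word_embeddings.weight-0-of-2.bin",
        },
    }
    print(json.dumps(index, indent=4))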