[checkpoint] refactored the API and added safetensors support (#3427)

* [checkpoint] refactored the API and added safetensors support

* polish code
This commit is contained in:
Frank Lee
2023-04-04 15:23:01 +08:00
committed by GitHub
parent 26b7aac0be
commit 1beb85cc25
9 changed files with 579 additions and 280 deletions

View File

@@ -1,7 +1,6 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Union
from typing import Union
import torch
import torch.nn as nn
@@ -10,7 +9,9 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
from .utils import has_index_file
__all__ = ['CheckpointIO']
class CheckpointIO(ABC):
@@ -25,15 +26,31 @@ class CheckpointIO(ABC):
>>> # load model from checkpoint
>>> model = checkpoint_io.load_model(model, 'model.pt')
>>>
>>> # save model to checkpoint
>>> # save model to checkpoint, any distributed tensor is gathered by default
>>> checkpoint_io.save_model(model, 'model.pt')
>>>
>>> # if the model contains distributed tensor, and you don't want to gather it
>>> # each rank will save its own shard of the distributed tensor
>>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False)
>>>
>>> # save model to sharded checkpoints
>>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
>>>
>>> # save model to sharded and assume we don't want to gather distributed tensors
>>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False)
>>>
>>> # Note:
>>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors
>>> # checkpoints to full tensor checkpoint should be done offline via our CLI
>>> # 2. you don't have to specify whether the model is sharded or not when loading the model
>>> # as it will be automatically detected
>>>
>>> # load model from sharded checkpoints
>>> model = checkpoint_io.load_model(model, './checkpoints/')
>>>
>>> # load model from unsharded checkpoints
>>> model = checkpoint_io.load_model(model, './checkpoints/')
>>>
>>> # load optimizer from checkpoint
>>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
>>>
@@ -58,21 +75,27 @@ class CheckpointIO(ABC):
1. a file path, e.g. 'model.pt'
2. a path to a json file which defines the index to the sharded checkpoint
3. a path to a folder containing a unique .index.json file for sharded checkpoint
Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
# since we only support loaded sharded and unsharded weight format
# containing no distributed tensors, dtensor -> full tensor conversion
# should be done offline via our CLI
# the existence of index file means it is a sharded checkpoint
ckpt_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(ckpt_path)
index_file_exists, index_file_path = has_index_file(checkpoint)
# return the origin model instead of the unwrapped model
origin_model = model
if isinstance(model, ModelWrapper):
model = model.unwrap()
if is_sharded:
self.load_sharded_model(model, ckpt_path, strict)
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
self.load_unsharded_model(model, ckpt_path, strict)
self.load_unsharded_model(model, checkpoint, strict)
return origin_model
@@ -80,8 +103,10 @@ class CheckpointIO(ABC):
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: str = None,
size_per_shard: int = 1024):
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""
Save model to checkpoint.
@@ -103,17 +128,19 @@ class CheckpointIO(ABC):
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
if isinstance(model, ModelWrapper):
model = model.unwrap()
if shard:
self.save_sharded_model(model, checkpoint, prefix, size_per_shard)
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
self.save_unsharded_model(model, checkpoint)
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""
@@ -123,22 +150,27 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the
"""
ckpt_path = Path(checkpoint)
is_sharded = self.is_sharded_checkpoint(ckpt_path)
index_file_exists, index_file_path = has_index_file(checkpoint)
if is_sharded:
self.load_sharded_optimizer(optimizer, ckpt_path)
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
raise ValueError(f'Cannot find index file in {checkpoint}')
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path)
else:
self.load_unsharded_optimizer(optimizer, ckpt_path)
self.load_unsharded_optimizer(optimizer, checkpoint)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save optimizer to checkpoint.
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
Args:
optimizer (Optimizer): optimizer to be saved.
@@ -148,30 +180,33 @@ class CheckpointIO(ABC):
3. a path to a folder containing a unique .index.json file for sharded checkpoint
shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
multiple files. The optimizer shards will be specificed by a `optimizer.index.json` file.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:
self.save_unsharded_optimizer(optimizer, checkpoint)
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
# ========================================================
# Abstract methods for model loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
"""
Load model from sharded checkpoint.
Args:
model (nn.Module): model to be loaded.
checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
"""
Load model from unsharded checkpoint.
@@ -184,26 +219,31 @@ class CheckpointIO(ABC):
pass
@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int, use_safetensors: bool):
"""
Save model to sharded checkpoint.
Args:
model (nn.Module): model to be saved.
checkpoint (Path): checkpoint path. It should be a directory path.
checkpoint (str): checkpoint path. It should be a directory path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
prefix (str): prefix for the model checkpoint.
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to unsharded checkpoint.
Args:
model (nn.Module): model to be saved.
checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
pass
@@ -212,13 +252,13 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
"""
Load optimizer from sharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
@@ -236,26 +276,29 @@ class CheckpointIO(ABC):
pass
@abstractmethod
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save optimizer to sharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (Path): checkpoint path. It should be a directory path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
"""
Save optimizer to unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be saved.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
pass
@@ -264,7 +307,6 @@ class CheckpointIO(ABC):
# as this is quite standard, there is no need
# to make them abstract
# ============================================
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint.
@@ -285,231 +327,3 @@ class CheckpointIO(ABC):
"""
state_dict = torch.load(checkpoint)
lr_scheduler.load_state_dict(state_dict)
# ========================================
# Helper functions for loading state dict
# ========================================
def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
"""
Get the index file path for a sharded checkpoint.
Args:
checkpoint_path (Path): path to the checkpoint.
Returns:
Path: path to the index file.
"""
if checkpoint_path.is_file():
# check if it is .index.json
if checkpoint_path.name.endswith('.index.json'):
return checkpoint_path
else:
raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ')
elif checkpoint_path.is_dir():
# check if there is only one a file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.json'))
if len(index_files) == 1:
return index_files[0]
else:
raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
def is_sharded_checkpoint(self, checkpoint_path: Path):
"""
Check whether the checkpoint is sharded.
Args:
checkpoint (str): checkpoint path.
Returns:
bool: whether the checkpoint is sharded.
"""
if checkpoint_path.is_file():
# check if it is .index.json
if checkpoint_path.name.endswith('.index.json'):
return True
else:
return False
elif checkpoint_path.is_dir():
# check if there is only one a file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.json'))
if len(index_files) == 1:
return True
else:
raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
def get_checkpoint_shard_filenames(self, index_file_path: Path):
"""
Get checkpoint shard filenames from a json file.
Args:
index_file_path (Path): path to the json file.
Returns:
list: checkpoint shard filenames.
"""
with open(str(index_file_path), 'r') as f:
shard_filenames = json.load(f)
if "weight_map" in index:
index = index["weight_map"]
checkpoint_root_path = index_file_path.absolute().parent
# read the checkpoint file list from the json file and get a list of unique file names
checkpoint_files = sorted(list(set(index.values())))
# get the absolute paths for all checkpoint files
checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
return shard_filenames
def load_safetensors_state_dict(self, *args, **kwargs):
"""
Load safetensors state dict from checkpoint.
"""
# TODO(FrankLeeeee): support huggingface safetensors
raise NotImplementedError("This method is not implemented to support safe tensors")
def load_state_dict(self, checkpoint_file_path: Path):
"""
Load state dict from checkpoint.
Args:
checkpoint_file_path (Path): path to the checkpoint file.
Returns:
dict: state dict.
"""
return torch.load(str(checkpoint_file_path))
# ======================================
# Helper functions for saving state dict
# ======================================
def save_safetensors_state_dict(self, *args, **kwargs):
"""
Save safetensors state dict to checkpoint.
"""
# TODO(FrankLeeeee): support huggingface safetensors
raise NotImplementedError("This method is not implemented to support safe tensors")
def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
"""
Generate checkpoint shard file name.
Args:
index (int): index of the shard.
total_number (int): total number of shards.
prefix (str): prefix of the shard file name. Default: None.
"""
if prefix is None:
return f"{index}-of-{total_number}.bin"
else:
return f"{prefix}-{index}-of-{total_number}.bin"
def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
"""
Save state dict to checkpoint.
Args:
state_dict (dict): state dict.
checkpoint_file_path (Path): path to the checkpoint file.
"""
torch.save(state_dict, str(checkpoint_file_path))
def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
checkpoint_path: Path):
"""
Save state dict as shard.
Args:
state_dict (dict): state dict.
checkpoint_path (Path): path to the checkpoint file.
"""
# generate the shard name
shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
shard_file_path = checkpoint_path.joinpath(shard_file_name)
# save the shard
self.save_checkpoint(state_dict, shard_file_path)
def calculate_param_size(self, param: torch.Tensor):
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
If so, a new shard should be created.
ArgsL
param (torch.Tensor): parameter tensor.
"""
# TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
return param.numel() * param.element_size() / 1024 / 1024
class ShardCheckpointIndexFile:
"""
This class is a data structure to keep the content in the index.json file for sharded checkpoint.
Example:
>>> index = ShardCheckpointIndexFile()
>>> index.load('index.json')
>>> index.append_metadata('model_type', 'bert')
>>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
>>> index.export('index.json')
"""
def __init__(self) -> None:
self.metadata: dict = dict()
self.weight_map: dict = dict()
def load(self, json_path: str):
"""
Load the index file from a json file.
Args:
json_path (str): path to the json file.
"""
# load the json file
with open(json_path, 'r') as f:
index = json.load(f)
# assign attributes if exists
if "metadata" in index:
self.metadata = index["metadata"]
if "weight_map" in index:
self.weight_map = index["weight_map"]
def export(self, json_path: str):
"""
Export the index file to a json file.
Args:
json_path (str): path to the json file.
"""
# create the index file
index = dict()
index["metadata"] = self.metadata
index["weight_map"] = self.weight_map
# export the index file
with open(json_path, 'w') as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
"""
Append a weight map entry to the index file.
Args:
param_name (str): name of the parameter.
shard_file (str): name of the shard file.
"""
self.weight_map[param_name] = shard_file
def append_meta_data(self, name: str, val: Any):
"""
Append a metadata entry to the index file.
Args:
name (str): name of the metadata.
val (Any): value of the metadata.
"""
self.metadata[name] = val