From 1beb85cc25f35d51083bde6fbaa99a5c4c7fd387 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 4 Apr 2023 15:23:01 +0800 Subject: [PATCH] [checkpoint] refactored the API and added safetensors support (#3427) * [checkpoint] refactored the API and added safetensors support * polish code --- colossalai/booster/plugin/torch_ddp_plugin.py | 4 +- colossalai/checkpoint_io/__init__.py | 5 +- .../checkpoint_io/checkpoint_io_base.py | 332 ++++-------------- .../checkpoint_io/general_checkpoint_io.py | 53 ++- colossalai/checkpoint_io/index_file.py | 150 ++++++++ colossalai/checkpoint_io/utils.py | 278 +++++++++++++++ requirements/requirements.txt | 1 + .../test_plugin/test_torch_ddp_plugin.py | 23 ++ .../test_general_checkpoint_io.py | 13 +- 9 files changed, 579 insertions(+), 280 deletions(-) create mode 100644 colossalai/checkpoint_io/index_file.py create mode 100644 colossalai/checkpoint_io/utils.py diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index d7f3d22d9..e2abe11ba 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -33,7 +33,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): # the model should be unwrapped in self.load_model via ModelWrapper.unwrap return super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str): + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool): """ Save model to checkpoint but only on master process. """ @@ -41,7 +41,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_unsharded_model(model, checkpoint) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. 
""" diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index 3cec630b2..c25048e25 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,4 +1,5 @@ -from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile +from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index d6eef7a96..b91b00831 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,7 +1,6 @@ -import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Union +from typing import Union import torch import torch.nn as nn @@ -10,7 +9,9 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.interface import ModelWrapper -__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile'] +from .utils import has_index_file + +__all__ = ['CheckpointIO'] class CheckpointIO(ABC): @@ -25,15 +26,31 @@ class CheckpointIO(ABC): >>> # load model from checkpoint >>> model = checkpoint_io.load_model(model, 'model.pt') >>> - >>> # save model to checkpoint + >>> # save model to checkpoint, any distributed tensor is gathered by default >>> checkpoint_io.save_model(model, 'model.pt') >>> + >>> # if the model contains distributed tensor, and you don't want to gather it + >>> # each rank will save its own shard of the distributed tensor + >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False) + >>> >>> # save model to sharded checkpoints >>> checkpoint_io.save_model(model, './checkpoints/', shard=True) >>> + >>> # save model to sharded and assume we don't want to gather distributed tensors + >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False) + >>> + >>> # Note: + >>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors + >>> # checkpoints to full tensor checkpoint should be done offline via our CLI + >>> # 2. you don't have to specify whether the model is sharded or not when loading the model + >>> # as it will be automatically detected + >>> >>> # load model from sharded checkpoints >>> model = checkpoint_io.load_model(model, './checkpoints/') >>> + >>> # load model from unsharded checkpoints + >>> model = checkpoint_io.load_model(model, './checkpoints/') + >>> >>> # load optimizer from checkpoint >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt') >>> @@ -58,21 +75,27 @@ class CheckpointIO(ABC): 1. a file path, e.g. 'model.pt' 2. a path to a json file which defines the index to the sharded checkpoint 3. a path to a folder containing a unique .index.json file for sharded checkpoint + Distributed tensors cannot be loaded directly unless gathered offline via our CLI. strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. 
""" + # since we only support loaded sharded and unsharded weight format + # containing no distributed tensors, dtensor -> full tensor conversion + # should be done offline via our CLI + # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) - is_sharded = self.is_sharded_checkpoint(ckpt_path) + index_file_exists, index_file_path = has_index_file(checkpoint) + # return the origin model instead of the unwrapped model origin_model = model if isinstance(model, ModelWrapper): model = model.unwrap() - if is_sharded: - self.load_sharded_model(model, ckpt_path, strict) + if index_file_exists: + self.load_sharded_model(model, index_file_path, strict) else: - self.load_unsharded_model(model, ckpt_path, strict) + self.load_unsharded_model(model, checkpoint, strict) return origin_model @@ -80,8 +103,10 @@ class CheckpointIO(ABC): model: Union[nn.Module, ModelWrapper], checkpoint: str, shard: bool = False, + gather_dtensor: bool = True, prefix: str = None, - size_per_shard: int = 1024): + size_per_shard: int = 1024, + use_safetensors: bool = False): """ Save model to checkpoint. @@ -103,17 +128,19 @@ class CheckpointIO(ABC): shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. prefix (str): prefix for the model checkpoint file name when shard=True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. + use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ if isinstance(model, ModelWrapper): model = model.unwrap() if shard: - self.save_sharded_model(model, checkpoint, prefix, size_per_shard) + self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: - self.save_unsharded_model(model, checkpoint) + self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) def load_optimizer(self, optimizer: Optimizer, checkpoint: str): """ @@ -123,22 +150,27 @@ class CheckpointIO(ABC): optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the """ - ckpt_path = Path(checkpoint) - is_sharded = self.is_sharded_checkpoint(ckpt_path) + index_file_exists, index_file_path = has_index_file(checkpoint) - if is_sharded: - self.load_sharded_optimizer(optimizer, ckpt_path) + if Path(checkpoint).is_dir() and not index_file_exists: + # if the checkpoint is a directory and there is no index file, raise error + raise ValueError(f'Cannot find index file in {checkpoint}') + + if index_file_exists: + # the existence of index file means it is a sharded checkpoint + self.load_sharded_optimizer(optimizer, index_file_path) else: - self.load_unsharded_optimizer(optimizer, ckpt_path) + self.load_unsharded_optimizer(optimizer, checkpoint) def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, + gather_dtensor=True, prefix: str = None, size_per_shard: int = 1024): """ - Save optimizer to checkpoint. + Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. Args: optimizer (Optimizer): optimizer to be saved. 
@@ -148,30 +180,33 @@
                 3. a path to a folder containing a unique .index.json file for sharded checkpoint
             shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into multiple
                 files. The optimizer shards will be specified by an `optimizer.index.json` file.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard)
+            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
     # ========================================================
     # Abstract methods for model loading/saving implementation
     # ========================================================
 
     @abstractmethod
-    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
         """
         Load model from sharded checkpoint.
 
         Args:
             model (nn.Module): model to be loaded.
-            checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
+            index_file_path (str): checkpoint path. It should be a path to the .index.json file or to a directory which contains a .index.json file.
+            strict (bool): whether to strictly enforce that the param names in
+                the checkpoint match the keys returned by this module's state_dict().
         """
         pass
 
     @abstractmethod
-    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         """
         Load model from unsharded checkpoint.
 
@@ -184,26 +219,31 @@
         pass
 
     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
+                           size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.
 
         Args:
             model (nn.Module): model to be saved.
-            checkpoint (Path): checkpoint path. It should be a directory path.
+            checkpoint (str): checkpoint path. It should be a directory path.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
             prefix (str): prefix for the model checkpoint.
             size_per_shard (int): size per shard in MB.
+            use_safetensors (bool): whether to use safetensors.
         """
         pass
 
     @abstractmethod
-    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to unsharded checkpoint.
 
         Args:
             model (nn.Module): model to be saved.
-            checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary.
+            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device.
+            use_safetensors (bool): whether to use safetensors.
""" pass @@ -212,13 +252,13 @@ class CheckpointIO(ABC): # ======================================================== @abstractmethod - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): """ Load optimizer from sharded checkpoint. Args: optimizer (Optimizer): optimizer to be loaded. - checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. + index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ @@ -236,26 +276,29 @@ class CheckpointIO(ABC): pass @abstractmethod - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int): """ Save optimizer to sharded checkpoint. Args: optimizer (Optimizer): optimizer to be saved. checkpoint (Path): checkpoint path. It should be a directory path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ pass @abstractmethod - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): """ Save optimizer to unsharded checkpoint. Args: optimizer (Optimizer): optimizer to be saved. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. """ pass @@ -264,7 +307,6 @@ class CheckpointIO(ABC): # as this is quite standard, there is no need # to make them abstract # ============================================ - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ Save lr scheduler to checkpoint. @@ -285,231 +327,3 @@ class CheckpointIO(ABC): """ state_dict = torch.load(checkpoint) lr_scheduler.load_state_dict(state_dict) - - # ======================================== - # Helper functions for loading state dict - # ======================================== - - def get_sharded_checkpoint_index_file(self, checkpoint_path: Path): - """ - Get the index file path for a sharded checkpoint. - - Args: - checkpoint_path (Path): path to the checkpoint. - - Returns: - Path: path to the index file. - """ - if checkpoint_path.is_file(): - # check if it is .index.json - if checkpoint_path.name.endswith('.index.json'): - return checkpoint_path - else: - raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ') - elif checkpoint_path.is_dir(): - # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.json')) - if len(index_files) == 1: - return index_files[0] - else: - raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ') - - def is_sharded_checkpoint(self, checkpoint_path: Path): - """ - Check whether the checkpoint is sharded. - - Args: - checkpoint (str): checkpoint path. - - Returns: - bool: whether the checkpoint is sharded. 
- """ - if checkpoint_path.is_file(): - # check if it is .index.json - if checkpoint_path.name.endswith('.index.json'): - return True - else: - return False - elif checkpoint_path.is_dir(): - # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.json')) - if len(index_files) == 1: - return True - else: - raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ') - - def get_checkpoint_shard_filenames(self, index_file_path: Path): - """ - Get checkpoint shard filenames from a json file. - - Args: - index_file_path (Path): path to the json file. - - Returns: - list: checkpoint shard filenames. - """ - with open(str(index_file_path), 'r') as f: - shard_filenames = json.load(f) - - if "weight_map" in index: - index = index["weight_map"] - - checkpoint_root_path = index_file_path.absolute().parent - - # read the checkpoint file list from the json file and get a list of unique file names - checkpoint_files = sorted(list(set(index.values()))) - - # get the absolute paths for all checkpoint files - checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files] - return shard_filenames - - def load_safetensors_state_dict(self, *args, **kwargs): - """ - Load safetensors state dict from checkpoint. - """ - # TODO(FrankLeeeee): support huggingface safetensors - raise NotImplementedError("This method is not implemented to support safe tensors") - - def load_state_dict(self, checkpoint_file_path: Path): - """ - Load state dict from checkpoint. - - Args: - checkpoint_file_path (Path): path to the checkpoint file. - - Returns: - dict: state dict. - """ - return torch.load(str(checkpoint_file_path)) - - # ====================================== - # Helper functions for saving state dict - # ====================================== - - def save_safetensors_state_dict(self, *args, **kwargs): - """ - Save safetensors state dict to checkpoint. - """ - # TODO(FrankLeeeee): support huggingface safetensors - raise NotImplementedError("This method is not implemented to support safe tensors") - - def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None): - """ - Generate checkpoint shard file name. - - Args: - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. Default: None. - """ - if prefix is None: - return f"{index}-of-{total_number}.bin" - else: - return f"{prefix}-{index}-of-{total_number}.bin" - - def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path): - """ - Save state dict to checkpoint. - - Args: - state_dict (dict): state dict. - checkpoint_file_path (Path): path to the checkpoint file. - """ - torch.save(state_dict, str(checkpoint_file_path)) - - def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str, - checkpoint_path: Path): - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (Path): path to the checkpoint file. - """ - # generate the shard name - shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix) - shard_file_path = checkpoint_path.joinpath(shard_file_name) - - # save the shard - self.save_checkpoint(state_dict, shard_file_path) - - def calculate_param_size(self, param: torch.Tensor): - """ - Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. - If so, a new shard should be created. 
- - ArgsL - param (torch.Tensor): parameter tensor. - """ - # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so - return param.numel() * param.element_size() / 1024 / 1024 - - -class ShardCheckpointIndexFile: - """ - This class is a data structure to keep the content in the index.json file for sharded checkpoint. - - Example: - >>> index = ShardCheckpointIndexFile() - >>> index.load('index.json') - >>> index.append_metadata('model_type', 'bert') - >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin') - >>> index.export('index.json') - """ - - def __init__(self) -> None: - self.metadata: dict = dict() - self.weight_map: dict = dict() - - def load(self, json_path: str): - """ - Load the index file from a json file. - - Args: - json_path (str): path to the json file. - """ - # load the json file - with open(json_path, 'r') as f: - index = json.load(f) - - # assign attributes if exists - if "metadata" in index: - self.metadata = index["metadata"] - if "weight_map" in index: - self.weight_map = index["weight_map"] - - def export(self, json_path: str): - """ - Export the index file to a json file. - - Args: - json_path (str): path to the json file. - """ - # create the index file - index = dict() - index["metadata"] = self.metadata - index["weight_map"] = self.weight_map - - # export the index file - with open(json_path, 'w') as f: - json.dump(index, f, indent=4) - - def append_weight_map(self, param_name: str, shard_file: str): - """ - Append a weight map entry to the index file. - - Args: - param_name (str): name of the parameter. - shard_file (str): name of the shard file. - """ - self.weight_map[param_name] = shard_file - - def append_meta_data(self, name: str, val: Any): - """ - Append a metadata entry to the index file. - - Args: - name (str): name of the metadata. - val (Any): value of the metadata. 
-        """
-        self.metadata[name] = val
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index cfabcfa55..c779f4c17 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -4,42 +4,67 @@
 import torch.nn as nn
 from torch.optim import Optimizer
 
 from .checkpoint_io_base import CheckpointIO
+from .index_file import CheckpointIndexFile
+from .utils import has_index_file, load_state_dict, save_state_dict
 
 __all__ = ['GeneralCheckpointIO']
 
 
 class GeneralCheckpointIO(CheckpointIO):
 
-    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
-        index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
+    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
+        # load the index file
+        index_file = CheckpointIndexFile.from_file(index_file_path)
 
         # iterate over the shard checkpoint files
         # and load each
-        shard_files = self.get_checkpoint_shard_filenames(index_file_path)
-        for shard_file in shard_files:
-            shard_checkpoint = self.load_state_dict(shard_file)
+        index_file.assert_no_dtensor_checkpoint()
+        checkpoint_file_list, _ = index_file.get_checkpoint_filenames()
+        for shard_file in checkpoint_file_list:
+            shard_checkpoint = load_state_dict(shard_file)
             model.load_state_dict(shard_checkpoint, strict=strict)
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
-        checkpoint = self.load_state_dict(str(checkpoint))
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                           size_per_shard: int, use_safetensors: bool):
         # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
-        self.save_checkpoint(model.state_dict(), checkpoint)
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        state_dict = model.state_dict()
+
+        # TODO(FrankLeeeee): add support for gather_dtensor
+        if gather_dtensor:
+            pass
+
+        # save the checkpoint
+        save_state_dict(state_dict, checkpoint, use_safetensors)
 
     def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
 
     def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = self.load_state_dict(checkpoint)
+        checkpoint = load_state_dict(checkpoint)
         optimizer.load_state_dict(checkpoint)
 
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+    ):
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        self.save_checkpoint(optimizer.state_dict(), checkpoint)
+    def save_unsharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+    ):
+        # TODO(FrankLeeeee): handle distributed tensors
+        save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
new file mode 100644
index 000000000..32ff1b762
--- /dev/null
+++ b/colossalai/checkpoint_io/index_file.py
@@ -0,0 +1,150 @@
+import json
+from pathlib import Path
+from typing import Any, List, Tuple, Union
+
+from .utils import is_dtensor_checkpoint
+
+__all__ = ['CheckpointIndexFile']
+
+
+class CheckpointIndexFile:
+    """
+    This class is a data structure to keep the content in the index.json file for sharded checkpoint.
+
+    Example:
+        >>> index = CheckpointIndexFile.from_file('model.index.json')
+        >>> index.append_meta_data('model_type', 'bert')
+        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin')
+        >>> index.export('new_index.json')
+    """
+
+    def __init__(self) -> None:
+        self.root_path = None
+        self.metadata: dict = dict()
+        self.weight_map: dict = dict()
+
+    @staticmethod
+    def from_file(index_path: Union[str, Path]):
+        """
+        Create a CheckpointIndexFile object from a json file.
+
+        Args:
+            index_path (str): path to the json file.
+
+        Returns:
+            CheckpointIndexFile: CheckpointIndexFile object.
+        """
+        index = CheckpointIndexFile()
+        index.load(index_path)
+        return index
+
+    def load(self, json_path: str):
+        """
+        Load the index file from a json file.
+
+        Args:
+            json_path (str): path to the json file.
+        """
+        # load the json file
+        with open(json_path, 'r') as f:
+            index = json.load(f)
+
+        # assign attributes if exists
+        if "metadata" in index:
+            self.metadata = index["metadata"]
+        if "weight_map" in index:
+            self.weight_map = index["weight_map"]
+
+        # assign the root directory for the index file
+        self.root_path = Path(json_path).absolute().parent
+
+    def export(self, json_path: str):
+        """
+        Export the index file to a json file.
+
+        Args:
+            json_path (str): path to the json file.
+        """
+        # create the index file
+        index = dict()
+        index["metadata"] = self.metadata
+        index["weight_map"] = self.weight_map
+
+        # export the index file
+        with open(json_path, 'w') as f:
+            json.dump(index, f, indent=4)
+
+    def append_weight_map(self, param_name: str, shard_file: str):
+        """
+        Append a weight map entry to the index file.
+
+        Args:
+            param_name (str): name of the parameter.
+            shard_file (str): name of the shard file.
+        """
+        self.weight_map[param_name] = shard_file
+
+    def append_meta_data(self, name: str, val: Any):
+        """
+        Append a metadata entry to the index file.
+
+        Args:
+            name (str): name of the metadata.
+            val (Any): value of the metadata.
+        """
+        self.metadata[name] = val
+
+    def contains_dtensor(self):
+        """
+        Check if the index file contains any distributed tensor. The distributed tensors will be stored in
+        `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map.
+
+        Returns:
+            bool: True if the index file contains any distributed tensor, False otherwise.
+        """
+        for value in self.weight_map.values():
+            if value.endswith(".*.bin") or value.endswith(".*.safetensors"):
+                return True
+        return False
+
+    def get_checkpoint_filenames(self) -> Tuple[List[str], List[str]]:
+        """
+        Get the checkpoint filenames in the weight map, separating regular checkpoint files from dtensor checkpoint files.
+
+        Returns:
+            Tuple[List[str], List[str]]: a tuple of (checkpoint_list, dtensor_list).
+ """ + # read the checkpoint file list from the json file and get a list of unique file names + checkpoint_files = sorted(list(set(self.weight_map.values()))) + + # get the absolute paths for all checkpoint files + checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files] + + dtensor_list = [] + checkpoint_list = [] + + for ckpt_file in checkpoint_files: + if is_dtensor_checkpoint(ckpt_file): + dtensor_list.append(ckpt_file) + else: + checkpoint_list.append(ckpt_file) + + return checkpoint_list, dtensor_list + + def assert_no_dtensor_checkpoint(self): + for val in self.weight_map.values(): + if is_dtensor_checkpoint(val): + raise ValueError(f"Checkpoint file {val} contains distributed tensor") + + def get_checkpoint_file(self, param_name: str) -> str: + """ + Get the checkpoint file name for a parameter. + + Args: + param_name (str): name of the parameter. + + Returns: + str: checkpoint file name. + """ + ckpt_path = self.weight_map[param_name] + return ckpt_path diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py new file mode 100644 index 000000000..76c9db0af --- /dev/null +++ b/colossalai/checkpoint_io/utils.py @@ -0,0 +1,278 @@ +from pathlib import Path +from typing import List, Optional, Tuple + +import torch + +# ====================================== +# General helper functions +# ====================================== + + +def calculate_tensor_size(tensor: torch.Tensor) -> float: + """ + Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. + If so, a new shard should be created. + + Args: + tenosr (torch.Tensor): the tensor to calculate size for. + + Returns: + float: size of the tensor in MB. + """ + return tensor.numel() * tensor.element_size() / 1024 / 1024 + + +def is_safetensors_available() -> bool: + """ + Check whether safetensors is available. + + Returns: + bool: whether safetensors is available. + """ + try: + import safetensors + return True + except ImportError: + return False + + +def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: + """ + Check whether the checkpoint file is a dtensor checkpoint. + + Args: + checkpoint_file_path (str): path to the checkpoint file. + + Returns: + bool: whether the checkpoint file is a dtensor checkpoint. + """ + if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): + return True + else: + return False + + +def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: + """ + Check whether the checkpoint file is a safetensor checkpoint. + + Args: + checkpoint_file_path (str): path to the checkpoint file. + + Returns: + bool: whether the checkpoint file is a safetensor checkpoint. + """ + if checkpoint_file_path.endswith('.safetensors'): + return True + else: + return False + + +# ====================================== +# Helper functions for saving state dict +# ====================================== + + +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." 
+        from safetensors.torch import save_file
+        save_file(state_dict, checkpoint_file_path)
+    else:
+        torch.save(state_dict, checkpoint_file_path)
+
+
+def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
+    """
+    Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
+    only one tensor.
+
+    Args:
+        name (str): name of the distributed parameter.
+        tensor (Tensor): tensor to be saved.
+        index_file (CheckpointIndexFile): index file of the checkpoint.
+        use_safetensors (bool): whether to use safetensors to save the checkpoint.
+    """
+    root_path = index_file.root_path
+    output_root_path = root_path.joinpath('dtensor')
+
+    # create directory
+    output_root_path.mkdir(exist_ok=True)
+
+    # save tensor to this directory
+    # TODO(YuliangLiu): get index of the tensor shard
+    # e.g. index =
+    index = 0
+
+    # save tensor to file
+    ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
+    ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
+
+    # dtensor ckpt file always contains only one tensor
+    state_dict = {name: tensor}
+    save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
+
+    # update the weight map
+    # * means all shards
+    ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
+    index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
+
+
+def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
+    """
+    Get checkpoint file suffix, including the leading dot.
+
+    Args:
+        use_safetensors (bool): whether to use safetensors to save the checkpoint.
+
+    Returns:
+        str: checkpoint file suffix.
+    """
+    if use_safetensors:
+        return '.safetensors'
+    else:
+        return '.bin'
+
+
+def generate_checkpoint_shard_file_name(index: int,
+                                        total_number: int,
+                                        use_safetensors: bool,
+                                        prefix: str = None) -> str:
+    """
+    Generate checkpoint shard file name. The suffix already carries the leading dot.
+
+    Args:
+        index (int): index of the shard.
+        total_number (int): total number of shards.
+        use_safetensors (bool): whether to use safetensors to save the checkpoint.
+        prefix (str): prefix of the shard file name. Default: None.
+
+    Returns:
+        str: checkpoint shard file name.
+    """
+    suffix = get_checkpoint_file_suffix(use_safetensors)
+
+    if prefix is None:
+        return f"{index:05d}-of-{total_number:05d}{suffix}"
+    else:
+        return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"
+
+
+def generate_dtensor_file_name(param_name: str, index: Union[int, str], use_safetensors: bool) -> str:
+    """
+    Generate dtensor file name.
+
+    Args:
+        param_name (str): name of the distributed parameter.
+        index (Union[int, str]): index of the shard, or '*' to refer to all shards in the weight map.
+        use_safetensors (bool): whether to use safetensors to save the checkpoint.
+
+    Returns:
+        str: dtensor file name.
+    """
+    suffix = get_checkpoint_file_suffix(use_safetensors)
+    return f'{param_name}.{index}{suffix}'
+
+
+def save_state_dict_as_shard(
+    state_dict: dict,
+    checkpoint_path: str,
+    index: int,
+    total_number: int,
+    use_safetensors: bool,
+    prefix: str = None,
+) -> None:
+    """
+    Save state dict as shard.
+
+    Args:
+        state_dict (dict): state dict.
+        checkpoint_path (str): path to the checkpoint file.
+        index (int): index of the shard.
+        total_number (int): total number of shards.
+        prefix (str): prefix of the shard file name.
+        use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ """ + # generate the shard name + shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) + shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() + + # save the shard + save_state_dict(state_dict, str(shard_file_path), use_safetensors) + + +# ======================================== +# Helper functions for loading state dict +# ======================================== + + +def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: + """ + Check whether the checkpoint has an index file. + + Args: + checkpoint_path (str): path to the checkpoint. + + Returns: + Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path) + """ + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + # check if it is .index.json + if checkpoint_path.name.endswith('.index.json'): + return True, checkpoint_path + else: + return False, None + elif checkpoint_path.is_dir(): + # check if there is only one a file ending with .index.json in this directory + index_files = list(checkpoint_path.glob('*.index.json')) + + # if we found a .index.json file, make sure there is only one + if len(index_files) > 0: + assert len( + index_files + ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' + + if len(index_files) == 1: + return True, index_files[0] + else: + return False, None + + +def load_state_dict(checkpoint_file_path: Path): + """ + Load state dict from checkpoint. + + Args: + checkpoint_file_path (Path): path to the checkpoint file. + + Returns: + dict: state dict. + """ + + assert not is_dtensor_checkpoint(checkpoint_file_path), \ + f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' + + if is_safetensor_checkpoint(checkpoint_file_path): + assert is_safetensors_available(), \ + f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' 
+        # load with safetensors
+        from safetensors import safe_open
+        state_dict = {}
+        with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
+            for k in f.keys():
+                state_dict[k] = f.get_tensor(k)
+        return state_dict
+
+    else:
+        # load with torch
+        return torch.load(checkpoint_file_path)
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 4e4f35edb..b34dc2e22 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -9,3 +9,4 @@ fabric
 contexttimer
 ninja
 torch>=1.11
+safetensors
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 2dcc5a5bb..71e8582cc 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -71,6 +71,29 @@
         batch_to_compare), 'Same number was found across ranks but expected it to be different'
 
 
+def check_checkpoint_save_and_load():
+    model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet']
+
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = model_fn()
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    criterion = lambda x: x.mean()
+    data = data_gen_fn()
+
+    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    output = model(**data)
+    output = output_transform_fn(output)
+    output_key = list(output.keys())[0]
+    loss = criterion(output[output_key])
+
+    booster.backward(loss, optimizer)
+
+
 def run_dist(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
index f9f0e03c4..dfbb16af4 100644
--- a/tests/test_checkpoint_io/test_general_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -1,5 +1,6 @@
 import tempfile
 
+import pytest
 import torch
 from torch.optim import Adam
 from torchvision.models import resnet18
@@ -14,7 +15,8 @@
 from colossalai.checkpoint_io import GeneralCheckpointIO
 # ========
 
-def test_unsharded_checkpoint():
+@pytest.mark.parametrize('use_safetensors', [True, False])
+def test_unsharded_checkpoint(use_safetensors: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)
@@ -29,12 +31,16 @@
     optimizer.step()
 
     # create a temp file for checkpoint
-    model_ckpt_tempfile = tempfile.NamedTemporaryFile()
+    if use_safetensors:
+        suffix = ".safetensors"
+    else:
+        suffix = ".bin"
+    model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
     optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
 
     # save the model and optimizer
     ckpt_io = GeneralCheckpointIO()
-    ckpt_io.save_model(model, model_ckpt_tempfile.name)
+    ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
     ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
 
     # create new model
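
A minimal end-to-end sketch of the refactored API above, matching the class docstring
examples. The resnet18 model and the file paths are illustrative, not part of this patch:

    from torch.optim import Adam
    from torchvision.models import resnet18

    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = resnet18()
    optimizer = Adam(model.parameters(), lr=1e-3)
    ckpt_io = GeneralCheckpointIO()

    # unsharded save; distributed tensors are gathered by default
    ckpt_io.save_model(model, 'model.safetensors', use_safetensors=True)
    # optimizer states are always saved with torch.save, never safetensors
    ckpt_io.save_optimizer(optimizer, 'optimizer.pt')

    # loading auto-detects sharded vs. unsharded checkpoints via the index file
    model = ckpt_io.load_model(model, 'model.safetensors')
    ckpt_io.load_optimizer(optimizer, 'optimizer.pt')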
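
The index file consumed by CheckpointIndexFile follows the Hugging Face sharded-checkpoint
layout: a "metadata" dict plus a "weight_map" from parameter names to shard files, with
shard paths resolved relative to the index file's directory. A sketch with hypothetical
parameter and shard names:

    from colossalai.checkpoint_io import CheckpointIndexFile

    index = CheckpointIndexFile()
    index.append_meta_data('total_size', 7085)    # illustrative metadata entry
    index.append_weight_map('fc.weight', '00001-of-00002.bin')
    index.append_weight_map('fc.bias', '00002-of-00002.bin')
    index.export('model.index.json')

    # reloading resolves shard paths against the directory of the index file
    index = CheckpointIndexFile.from_file('model.index.json')
    checkpoint_files, dtensor_files = index.get_checkpoint_filenames()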
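
With the suffix handling fixed above (get_checkpoint_file_suffix already returns the
leading dot), the naming helpers yield Hugging Face-style shard names. The parameter
name below is hypothetical:

    from colossalai.checkpoint_io.utils import (
        generate_checkpoint_shard_file_name,
        generate_dtensor_file_name,
    )

    # '00001-of-00002.bin'
    print(generate_checkpoint_shard_file_name(1, 2, use_safetensors=False))
    # 'model-00001-of-00002.safetensors'
    print(generate_checkpoint_shard_file_name(1, 2, True, prefix='model'))
    # 'fc.weight.0.safetensors'; passing '*' instead of an index matches all shards
    print(generate_dtensor_file_name('fc.weight', 0, use_safetensors=True))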
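
A round trip through the new low-level save/load helpers, assuming the safetensors
package added to requirements.txt is installed; the file name is illustrative:

    import torch

    from colossalai.checkpoint_io.utils import load_state_dict, save_state_dict

    state = {'weight': torch.ones(2, 2)}
    save_state_dict(state, 'tiny.safetensors', use_safetensors=True)

    # load_state_dict dispatches on the file suffix: safe_open for .safetensors,
    # torch.load otherwise
    loaded = load_state_dict('tiny.safetensors')
    assert torch.equal(loaded['weight'], state['weight'])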