diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index c779f4c17..2a76f1718 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -2,37 +2,35 @@ from pathlib import Path
 
 import torch.nn as nn
 from torch.optim import Optimizer
+import logging
+import os
+import json
+import gc
 
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
-from .utils import has_index_file, load_state_dict, save_state_dict
+from .utils import (
+    has_index_file,
+    load_state_dict,
+    save_state_dict,
+    is_safetensors_available,
+    shard_checkpoint,
+    load_shard_state_dict,
+    load_state_dict_into_model
+)
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
 
 __all__ = ['GeneralCheckpointIO']
 
 
 class GeneralCheckpointIO(CheckpointIO):
-
-    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
-        # load the index file
-        index_file = CheckpointIndexFile.from_file(index_file_path)
-
-        # iterate over the shard checkpoint files
-        # and load each
-        index_file.assert_no_dtensor_checkpoint()
-        checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
-        for shard_file in checkpoint_file_list:
-            shard_checkpoint = load_state_dict(shard_file)
-            model.load_state_dict(shard_checkpoint, strict=strict)
-
+    """
+    Checkpoint IO for general use: saves and loads model checkpoints (sharded or unsharded) and optimizer checkpoints.
+    """
     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
-                           size_per_shard: int, use_safetensors: bool):
-        # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
-
     def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         state_dict = model.state_dict()
@@ -68,3 +66,68 @@ class GeneralCheckpointIO(CheckpointIO):
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
+
+
+    def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False,
+                           prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
+        """
+        Save the model to multiple shard files, following the Huggingface sharded-checkpoint
+        layout: a set of weight shards plus an index file that maps each parameter to its shard.
+        """
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
+        # shard the checkpoint
+        state_dict = model.state_dict()
+        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+        # save the shard files
+        for shard_file, shard in shards.items():
+            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        # save the index file
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+        save_index_file = os.path.join(checkpoint_path, save_index_file)
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+        logging.info(
+            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+            f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+        """
+        Load a sharded model: read the checkpoint index file and load every shard it lists.
+        """
+        # infer from the index file name whether the shards were saved with safetensors
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # read checkpoint index file
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
+        missing_keys = ckpt_index_file.get_all_param_names()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            del state_dict
+            gc.collect()
+
+        if strict and len(missing_keys) > 0:
+            error_msgs = ['Missing key(s) in state_dict: {}. '.format(
+                ', '.join('"{}"'.format(k) for k in missing_keys))]
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                self.__class__.__name__, "\n\t".join(error_msgs)))
+
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 32ff1b762..89224787a 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -148,3 +148,9 @@ class CheckpointIndexFile:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
+
+    def get_all_param_names(self):
+        """
+        Return all parameter names recorded in the weight map.
+        """
+        return list(self.weight_map.keys())
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 76c9db0af..81b666da5 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,13 +1,19 @@
+# coding=utf-8
 from pathlib import Path
-from typing import List, Optional, Tuple
-
 import torch
+import torch.nn as nn
+from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
+from colossalai.tensor.d_tensor.d_tensor import DTensor
+
+SAFE_WEIGHTS_NAME = "model.safetensors"
+WEIGHTS_NAME = "model.bin"
+SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+WEIGHTS_INDEX_NAME = "model.bin.index.json"
 
 # ======================================
 # General helper functions
 # ======================================
-
 def calculate_tensor_size(tensor: torch.Tensor) -> float:
     """
     Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -68,6 +74,130 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
         return False
 
 
+# ======================================
+# Helper functions for saving shard file
+# ======================================
+def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
+    """
+    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint
+    does not exceed a given size (max_shard_size, in MB).
+    """
+    sharded_state_dicts = []
+    current_block = {}
+    current_block_size = 0
+    total_size = 0
+
+    for key, weight in state_dict.items():
+        if not isinstance(weight, DTensor):
+            weight_size = calculate_tensor_size(weight)
+
+            # If this weight is going to tip over the maximal size, we split.
+            if current_block_size + weight_size > max_shard_size:
+                sharded_state_dicts.append(current_block)
+                current_block = {}
+                current_block_size = 0
+
+            current_block[key] = weight
+            current_block_size += weight_size
+            total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {weights_name: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
+        shards[shard_file] = shard
+        for key in shard.keys():
+            weight_map[key] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
+    """
+    Load a single shard file and return its state dict.
+    """
+    if use_safetensors and not checkpoint_file.suffix == ".safetensors":
+        raise Exception("load the model using `safetensors`, but the checkpoint file does not end with .safetensors")
+    if use_safetensors:
+        from safetensors.torch import safe_open
+        from safetensors.torch import load_file as safe_load_file
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+        if metadata["format"] != "pt":
+            raise NotImplementedError(
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+            )
+        return safe_load_file(checkpoint_file)
+    else:
+        return torch.load(checkpoint_file)
+
+def load_state_dict_into_model(model: nn.Module, state_dict: Mapping, missing_keys: List, strict: bool = False):
+    r"""Copies parameters and buffers from :attr:`state_dict` into
+    this module and its descendants.
+
+    Args:
+        state_dict (dict): a dict containing parameters and
+            persistent buffers.
+        missing_keys (list): parameter names that have not been loaded yet;
+            keys found in this ``state_dict`` are removed from the list.
+    """
+    if not isinstance(state_dict, Mapping):
+        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
+
+    unexpected_keys: List[str] = []
+    sub_missing_keys: List[str] = []
+    error_msgs: List[str] = []
+
+    # copy state_dict so _load_from_state_dict can modify it
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = OrderedDict(state_dict)
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    def load(module: nn.Module, state_dict, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, state_dict, prefix + name + ".")
+
+    load(model, state_dict, "")
+    del load
+
+    # keys that were found in this shard are no longer missing
+    if len(missing_keys) > 0:
+        deleted_keys = []
+        for key in missing_keys:
+            if key not in sub_missing_keys:
+                deleted_keys.append(key)
+        for key in deleted_keys:
+            missing_keys.remove(key)
+
+    if strict:
+        if len(unexpected_keys) > 0:
+            error_msgs.append('Unexpected key(s) in state_dict: {}. '.format(
+                ', '.join('"{}"'.format(k) for k in unexpected_keys)))
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+
 # ======================================
 # Helper functions for saving state dict
 # ======================================
@@ -86,8 +216,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         assert is_safetensors_available(), "safetensors is not available."
         assert checkpoint_file_path.endswith('.safetensors'), \
             "safetensors only supports .safetensors suffix for checkpoint file."
-        from safetensors.torch import save_file
-        save_file(state_dict, checkpoint_file_path)
+        from safetensors.torch import save_file as safe_save_file
+        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
     else:
         torch.save(state_dict, checkpoint_file_path)
diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
index 0f78184f7..ca5ce1005 100644
--- a/tests/test_checkpoint_io/test_general_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -1,9 +1,12 @@
 import tempfile
-
 import pytest
 import torch
+import logging
 from torch.optim import Adam
 from torchvision.models import resnet18
+from pathlib import Path
+import os
+import subprocess
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.testing import clear_cache_before_run, parameterize
@@ -12,7 +15,7 @@ from colossalai.testing import clear_cache_before_run, parameterize
 # Note:
 # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
 # 2. we will test on both sharded and unsharded checkpoints
-# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it
+# 3. sharded model checkpoints are now supported and tested in test_sharded_checkpoint below
 # ========
 
 
@@ -53,27 +56,71 @@ def test_unsharded_checkpoint(use_safetensors: bool):
     ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
     ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
 
-    # do recursive check for the optimizer state dict
-    # if the value is a dict, compare its values
-    # if the value is a list, comapre all elements one-by-one
-    # if the value is a torch.Tensor, use torch.equal
-    # otherwise use assertEqual
-    def recursive_check(d1, d2):
-        for k, v in d1.items():
-            if isinstance(v, dict):
-                recursive_check(v, d2[k])
-            elif isinstance(v, list):
-                for i in range(len(v)):
-                    if isinstance(v[i], torch.Tensor):
-                        assert torch.equal(v[i], d2[k][i])
-                    else:
-                        assert v[i] == d2[k][i]
-            elif isinstance(v, torch.Tensor):
-                assert torch.equal(v, d2[k])
-            else:
-                assert v == d2[k]
-
     # check for model and optimizer state dict recursively
     recursive_check(model.state_dict(), new_model.state_dict())
     recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+@pytest.mark.parametrize('use_safetensors', [True, False])
+def test_sharded_checkpoint(use_safetensors: bool):
+    # create a model and optimizer
+    model = resnet18()
+    optimizer = Adam(model.parameters(), lr=0.001)
+    # create test data sample
+    x = torch.randn(1, 3, 224, 224)
+
+    # run fwd and bwd
+    y = model(x)
+    loss = y.sum()
+    loss.backward()
+    optimizer.step()
+
+    # pick the file suffix and index file name that match the serialization format
+    if use_safetensors:
+        suffix = ".safetensors"
+        SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+    else:
+        suffix = ".bin"
+        WEIGHTS_INDEX_NAME = "model.bin.index.json"
+
+    # create a temp directory for the sharded model checkpoint and a temp file for the optimizer
+    # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
+
+    # save the model and optimizer
+    ckpt_io = GeneralCheckpointIO()
+
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)
+
+    # create new model
+    new_model = resnet18()
+    new_optimizer = Adam(new_model.parameters(), lr=0.001)
+
+    ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True)
+    ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+
+    # check for model and optimizer state dict recursively
+    recursive_check(model.state_dict(), new_model.state_dict())
+    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
+
+
+# do recursive check for the model and optimizer state dicts
+# if the value is a dict, compare its values
+# if the value is a list, compare all elements one-by-one
+# if the value is a torch.Tensor, use torch.equal
+# otherwise use assertEqual
+def recursive_check(d1, d2):
+    for k, v in d1.items():
+        if isinstance(v, dict):
+            recursive_check(v, d2[k])
+        elif isinstance(v, list):
+            for i in range(len(v)):
+                if isinstance(v[i], torch.Tensor):
+                    assert torch.equal(v[i], d2[k][i])
+                else:
+                    assert v[i] == d2[k][i]
+        elif isinstance(v, torch.Tensor):
+            assert torch.equal(v, d2[k])
+        else:
+            assert v == d2[k]
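For reviewers, a minimal usage sketch of the sharded checkpoint API added by this patch. It is not part of the diff itself: the toy model, the `./ckpt_dir` path, and the 1 MB shard size are placeholder assumptions, `max_shard_size` is interpreted in MB (per `calculate_tensor_size`), and the `load_model` routing to the new `load_sharded_model` is inferred from `test_sharded_checkpoint` above.

```python
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

# A stand-in model that is large enough (~2 MB of weights) to be split across several shards.
model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])

ckpt_io = GeneralCheckpointIO()

# Writes model-00001-of-0000N.bin style shards plus model.bin.index.json into ./ckpt_dir.
ckpt_io.save_sharded_model(model, "./ckpt_dir", max_shard_size=1, use_safetensors=False)

# Reload into a fresh instance; as in test_sharded_checkpoint, the checkpoint directory
# containing the index file is passed to load_model.
new_model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])
ckpt_io.load_model(new_model, "./ckpt_dir", strict=True)
```

With `use_safetensors=True`, the same call would instead produce `model.safetensors` shards and a `model.safetensors.index.json` index, matching the `SAFE_WEIGHTS_NAME`/`SAFE_WEIGHTS_INDEX_NAME` constants added in `utils.py`.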