diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index b91b00831..3f8b0b0a6 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Union +from typing import Optional import torch import torch.nn as nn @@ -104,7 +105,7 @@ class CheckpointIO(ABC): checkpoint: str, shard: bool = False, gather_dtensor: bool = True, - prefix: str = None, + variant: str = None, size_per_shard: int = 1024, use_safetensors: bool = False): """ @@ -129,7 +130,7 @@ class CheckpointIO(ABC): multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. - prefix (str): prefix for the model checkpoint file name when shard=True. Default: None. + variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. use_safetensors (bool): whether to use safe tensors. Default: False. 
If set to True, the checkpoint will be saved """ @@ -138,7 +139,7 @@ class CheckpointIO(ABC): model = model.unwrap() if shard: - self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) + self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors) else: self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) @@ -219,7 +220,7 @@ class CheckpointIO(ABC): pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str, + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], size_per_shard: int, use_safetensors: bool): """ Save model to sharded checkpoint. diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2a76f1718..bf584f45d 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -6,6 +6,7 @@ import logging import os import json import gc +from typing import Optional from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile @@ -16,10 +17,12 @@ from .utils import ( is_safetensors_available, shard_checkpoint, load_shard_state_dict, - load_state_dict_into_model + load_state_dict_into_model, + add_variant ) from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + __all__ = ['GeneralCheckpointIO'] @@ -69,7 +72,7 @@ class GeneralCheckpointIO(CheckpointIO): def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, - prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False): + variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ implement this method as it can be supported by Huggingface model, save shard model, save model to multiple files @@ -83,6 +86,7 
@@ class GeneralCheckpointIO(CheckpointIO): # shard checkpoint state_dict = model.state_dict() weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) # Save the model @@ -92,7 +96,8 @@ class GeneralCheckpointIO(CheckpointIO): # save index file save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = os.path.join(checkpoint_path, save_index_file) + + save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) with open(save_index_file, "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 81b666da5..37d22d08d 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -4,11 +4,12 @@ import torch import torch.nn as nn from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple from colossalai.tensor.d_tensor.d_tensor import DTensor +import re SAFE_WEIGHTS_NAME = "model.safetensors" -WEIGHTS_NAME = "model.bin" +WEIGHTS_NAME = "pytorch_model.bin" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" -WEIGHTS_INDEX_NAME = "model.bin.index.json" +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" # ====================================== # General helper functions @@ -27,7 +28,6 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float: """ return tensor.numel() * tensor.element_size() / 1024 / 1024 - def is_safetensors_available() -> bool: """ Check whether safetensors is available. 
@@ -358,13 +358,14 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: checkpoint_path = Path(checkpoint_path) if checkpoint_path.is_file(): # check if it is .index.json - if checkpoint_path.name.endswith('.index.json'): + reg = re.compile("(.*?).index((\..*)?).json") + if reg.fullmatch(checkpoint_path.name) is not None: return True, checkpoint_path else: return False, None elif checkpoint_path.is_dir(): # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.json')) + index_files = list(checkpoint_path.glob('*.index.*json')) # if we found a .index.json file, make sure there is only one if len(index_files) > 0: @@ -406,3 +407,13 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch return torch.load(checkpoint_file_path) + + + +def add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None and len(variant) > 0: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name