[checkpoint] Sharded checkpoints need to be compatible with the naming format of HF checkpoint files (#3479)

* [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format

* [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

---------

Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
Co-authored-by: luchen <luchen@luchendeMBP.lan>
This commit is contained in:
jiangmingyan 2023-04-12 16:02:17 +08:00 committed by GitHub
parent 7182ac2a04
commit 366a035552
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 12 deletions

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -104,7 +105,7 @@ class CheckpointIO(ABC):
checkpoint: str, checkpoint: str,
shard: bool = False, shard: bool = False,
gather_dtensor: bool = True, gather_dtensor: bool = True,
prefix: str = None, variant: str = None,
size_per_shard: int = 1024, size_per_shard: int = 1024,
use_safetensors: bool = False): use_safetensors: bool = False):
""" """
@ -129,7 +130,7 @@ class CheckpointIO(ABC):
multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path. that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str): prefix for the model checkpoint file name when shard=True. Default: None. variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
""" """
@ -138,7 +139,7 @@ class CheckpointIO(ABC):
model = model.unwrap() model = model.unwrap()
if shard: if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
else: else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
@ -219,7 +220,7 @@ class CheckpointIO(ABC):
pass pass
@abstractmethod @abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str, def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
size_per_shard: int, use_safetensors: bool): size_per_shard: int, use_safetensors: bool):
""" """
Save model to sharded checkpoint. Save model to sharded checkpoint.

View File

@ -6,6 +6,7 @@ import logging
import os import os
import json import json
import gc import gc
from typing import Optional
from .checkpoint_io_base import CheckpointIO from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile from .index_file import CheckpointIndexFile
@ -16,10 +17,12 @@ from .utils import (
is_safetensors_available, is_safetensors_available,
shard_checkpoint, shard_checkpoint,
load_shard_state_dict, load_shard_state_dict,
load_state_dict_into_model load_state_dict_into_model,
add_variant
) )
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
__all__ = ['GeneralCheckpointIO'] __all__ = ['GeneralCheckpointIO']
@ -69,7 +72,7 @@ class GeneralCheckpointIO(CheckpointIO):
def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False): variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
""" """
implement this method as it can be supported by Huggingface model, implement this method as it can be supported by Huggingface model,
save shard model, save model to multiple files save shard model, save model to multiple files
@ -83,6 +86,7 @@ class GeneralCheckpointIO(CheckpointIO):
# shard checkpoint # shard checkpoint
state_dict = model.state_dict() state_dict = model.state_dict()
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
# Save the model # Save the model
@ -92,7 +96,8 @@ class GeneralCheckpointIO(CheckpointIO):
# save index file # save index file
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(checkpoint_path, save_index_file)
save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
with open(save_index_file, "w", encoding="utf-8") as f: with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n" content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content) f.write(content)

View File

@ -4,11 +4,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from colossalai.tensor.d_tensor.d_tensor import DTensor from colossalai.tensor.d_tensor.d_tensor import DTensor
import re
SAFE_WEIGHTS_NAME = "model.safetensors" SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "model.bin" WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "model.bin.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
# ====================================== # ======================================
# General helper functions # General helper functions
@ -27,7 +28,6 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
""" """
return tensor.numel() * tensor.element_size() / 1024 / 1024 return tensor.numel() * tensor.element_size() / 1024 / 1024
def is_safetensors_available() -> bool: def is_safetensors_available() -> bool:
""" """
Check whether safetensors is available. Check whether safetensors is available.
@ -358,13 +358,14 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
checkpoint_path = Path(checkpoint_path) checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file(): if checkpoint_path.is_file():
# check if it is .index.json # check if it is .index.json
if checkpoint_path.name.endswith('.index.json'): reg = re.compile("(.*?).index((\..*)?).json")
if reg.fullmatch(checkpoint_path.name) is not None:
return True, checkpoint_path return True, checkpoint_path
else: else:
return False, None return False, None
elif checkpoint_path.is_dir(): elif checkpoint_path.is_dir():
# check if there is only one a file ending with .index.json in this directory # check if there is only one a file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.json')) index_files = list(checkpoint_path.glob('*.index.*json'))
# if we found a .index.json file, make sure there is only one # if we found a .index.json file, make sure there is only one
if len(index_files) > 0: if len(index_files) > 0:
@ -406,3 +407,13 @@ def load_state_dict(checkpoint_file_path: Path):
else: else:
# load with torch # load with torch
return torch.load(checkpoint_file_path) return torch.load(checkpoint_file_path)
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """
    Insert a variant tag into a checkpoint file name, just before its last
    dot-separated component.

    For example, ``pytorch_model.bin`` with variant ``fp16`` becomes
    ``pytorch_model.fp16.bin``.

    Args:
        weights_name (str): file name to customize.
        variant (Optional[str]): tag to insert. When it is None or empty,
            the name is returned unchanged.

    Returns:
        str: the file name with the variant inserted, or the original name.
    """
    # nothing to insert: hand the name back untouched
    if variant is None or len(variant) == 0:
        return weights_name

    # place the variant right before the final dot-separated piece
    pieces = weights_name.split(".")
    pieces.insert(len(pieces) - 1, variant)
    return ".".join(pieces)