[checkpoint] Shard saved checkpoint need to be compatible with the naming format of hf checkpoint files (#3479)

* [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format

* [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

---------

Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
Co-authored-by: luchen <luchen@luchendeMBP.lan>
This commit is contained in:
jiangmingyan
2023-04-12 16:02:17 +08:00
committed by GitHub
parent 7182ac2a04
commit 366a035552
3 changed files with 29 additions and 12 deletions

View File

@@ -4,11 +4,12 @@ import torch
import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from colossalai.tensor.d_tensor.d_tensor import DTensor
import re
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "model.bin"
WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "model.bin.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
# ======================================
# General helper functions
@@ -27,7 +28,6 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
return tensor.numel() * tensor.element_size() / 1024 / 1024
def is_safetensors_available() -> bool:
"""
Check whether safetensors is available.
@@ -358,13 +358,14 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file():
# check if it is .index.json
if checkpoint_path.name.endswith('.index.json'):
reg = re.compile("(.*?).index((\..*)?).json")
if reg.fullmatch(checkpoint_path.name) is not None:
return True, checkpoint_path
else:
return False, None
elif checkpoint_path.is_dir():
# check if there is only one a file ending with .index.json in this directory
index_files = list(checkpoint_path.glob('*.index.json'))
index_files = list(checkpoint_path.glob('*.index.*json'))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
@@ -406,3 +407,13 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path)
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None and len(variant) > 0:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
return weights_name