mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 12:12:46 +00:00
[checkpoint] Shard saved checkpoint need to be compatible with the naming format of hf checkpoint files (#3479)
* [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format * [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format * [checkpoint] Shard saved checkpoint add 'variant' field to customize filename * [checkpoint] Shard saved checkpoint add 'variant' field to customize filename * [checkpoint] Shard saved checkpoint add 'variant' field to customize filename * [checkpoint] Shard saved checkpoint add 'variant' field to customize filename --------- Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local> Co-authored-by: luchen <luchen@luchendeMBP.lan>
This commit is contained in:
parent
7182ac2a04
commit
366a035552
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -104,7 +105,7 @@ class CheckpointIO(ABC):
|
|||||||
checkpoint: str,
|
checkpoint: str,
|
||||||
shard: bool = False,
|
shard: bool = False,
|
||||||
gather_dtensor: bool = True,
|
gather_dtensor: bool = True,
|
||||||
prefix: str = None,
|
variant: str = None,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
use_safetensors: bool = False):
|
use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
@ -129,7 +130,7 @@ class CheckpointIO(ABC):
|
|||||||
multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
|
multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
|
||||||
that the checkpoint path is a directory path instead of a file path.
|
that the checkpoint path is a directory path instead of a file path.
|
||||||
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
|
||||||
prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
|
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
|
||||||
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
|
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
|
||||||
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
|
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
|
||||||
"""
|
"""
|
||||||
@ -138,7 +139,7 @@ class CheckpointIO(ABC):
|
|||||||
model = model.unwrap()
|
model = model.unwrap()
|
||||||
|
|
||||||
if shard:
|
if shard:
|
||||||
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
|
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
|
||||||
else:
|
else:
|
||||||
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||||
|
|
||||||
@ -219,7 +220,7 @@ class CheckpointIO(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
|
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
|
||||||
size_per_shard: int, use_safetensors: bool):
|
size_per_shard: int, use_safetensors: bool):
|
||||||
"""
|
"""
|
||||||
Save model to sharded checkpoint.
|
Save model to sharded checkpoint.
|
||||||
|
@ -6,6 +6,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import gc
|
import gc
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .checkpoint_io_base import CheckpointIO
|
from .checkpoint_io_base import CheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
@ -16,10 +17,12 @@ from .utils import (
|
|||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
shard_checkpoint,
|
shard_checkpoint,
|
||||||
load_shard_state_dict,
|
load_shard_state_dict,
|
||||||
load_state_dict_into_model
|
load_state_dict_into_model,
|
||||||
|
add_variant
|
||||||
)
|
)
|
||||||
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
|
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['GeneralCheckpointIO']
|
__all__ = ['GeneralCheckpointIO']
|
||||||
|
|
||||||
|
|
||||||
@ -69,7 +72,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||||||
|
|
||||||
|
|
||||||
def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
|
def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
|
||||||
prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
|
variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
|
||||||
"""
|
"""
|
||||||
implement this method as it can be supported by Huggingface model,
|
implement this method as it can be supported by Huggingface model,
|
||||||
save shard model, save model to multiple files
|
save shard model, save model to multiple files
|
||||||
@ -83,6 +86,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||||||
# shard checkpoint
|
# shard checkpoint
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||||
|
weights_name = add_variant(weights_name, variant)
|
||||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
@ -92,7 +96,8 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||||||
|
|
||||||
# save index file
|
# save index file
|
||||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||||
save_index_file = os.path.join(checkpoint_path, save_index_file)
|
|
||||||
|
save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
|
||||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
f.write(content)
|
f.write(content)
|
||||||
|
@ -4,11 +4,12 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
|
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
|
||||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||||
|
import re
|
||||||
|
|
||||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||||
WEIGHTS_NAME = "model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||||
WEIGHTS_INDEX_NAME = "model.bin.index.json"
|
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||||
|
|
||||||
# ======================================
|
# ======================================
|
||||||
# General helper functions
|
# General helper functions
|
||||||
@ -27,7 +28,6 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
|
|||||||
"""
|
"""
|
||||||
return tensor.numel() * tensor.element_size() / 1024 / 1024
|
return tensor.numel() * tensor.element_size() / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
def is_safetensors_available() -> bool:
|
def is_safetensors_available() -> bool:
|
||||||
"""
|
"""
|
||||||
Check whether safetensors is available.
|
Check whether safetensors is available.
|
||||||
@ -358,13 +358,14 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
|
|||||||
checkpoint_path = Path(checkpoint_path)
|
checkpoint_path = Path(checkpoint_path)
|
||||||
if checkpoint_path.is_file():
|
if checkpoint_path.is_file():
|
||||||
# check if it is .index.json
|
# check if it is .index.json
|
||||||
if checkpoint_path.name.endswith('.index.json'):
|
reg = re.compile("(.*?).index((\..*)?).json")
|
||||||
|
if reg.fullmatch(checkpoint_path.name) is not None:
|
||||||
return True, checkpoint_path
|
return True, checkpoint_path
|
||||||
else:
|
else:
|
||||||
return False, None
|
return False, None
|
||||||
elif checkpoint_path.is_dir():
|
elif checkpoint_path.is_dir():
|
||||||
# check if there is only one a file ending with .index.json in this directory
|
# check if there is only one a file ending with .index.json in this directory
|
||||||
index_files = list(checkpoint_path.glob('*.index.json'))
|
index_files = list(checkpoint_path.glob('*.index.*json'))
|
||||||
|
|
||||||
# if we found a .index.json file, make sure there is only one
|
# if we found a .index.json file, make sure there is only one
|
||||||
if len(index_files) > 0:
|
if len(index_files) > 0:
|
||||||
@ -406,3 +407,13 @@ def load_state_dict(checkpoint_file_path: Path):
|
|||||||
else:
|
else:
|
||||||
# load with torch
|
# load with torch
|
||||||
return torch.load(checkpoint_file_path)
|
return torch.load(checkpoint_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
||||||
|
if variant is not None and len(variant) > 0:
|
||||||
|
splits = weights_name.split(".")
|
||||||
|
splits = splits[:-1] + [variant] + splits[-1:]
|
||||||
|
weights_name = ".".join(splits)
|
||||||
|
|
||||||
|
return weights_name
|
||||||
|
Loading…
Reference in New Issue
Block a user