mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[checkpoint] Sharded checkpoint saving needs to be compatible with the naming format of HF checkpoint files (#3479)
* [checkpoint] support Hugging Face-style sharded checkpoints, compatible with the HF file naming format * [checkpoint] support Hugging Face-style sharded checkpoints, compatible with the HF file naming format * [checkpoint] add a 'variant' field to sharded checkpoint saving to customize the filename * [checkpoint] add a 'variant' field to sharded checkpoint saving to customize the filename * [checkpoint] add a 'variant' field to sharded checkpoint saving to customize the filename * [checkpoint] add a 'variant' field to sharded checkpoint saving to customize the filename --------- Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local> Co-authored-by: luchen <luchen@luchendeMBP.lan>
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
from typing import Optional
|
||||
|
||||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
@@ -16,10 +17,12 @@ from .utils import (
|
||||
is_safetensors_available,
|
||||
shard_checkpoint,
|
||||
load_shard_state_dict,
|
||||
load_state_dict_into_model
|
||||
load_state_dict_into_model,
|
||||
add_variant
|
||||
)
|
||||
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
|
||||
|
||||
|
||||
__all__ = ['GeneralCheckpointIO']
|
||||
|
||||
|
||||
@@ -69,7 +72,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
|
||||
|
||||
def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
|
||||
prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
|
||||
variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
|
||||
"""
|
||||
implement this method as it can be supported by Huggingface model,
|
||||
save shard model, save model to multiple files
|
||||
@@ -83,6 +86,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
# shard checkpoint
|
||||
state_dict = model.state_dict()
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
||||
|
||||
# Save the model
|
||||
@@ -92,7 +96,8 @@ class GeneralCheckpointIO(CheckpointIO):
|
||||
|
||||
# save index file
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(checkpoint_path, save_index_file)
|
||||
|
||||
save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
Reference in New Issue
Block a user