Mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpoint] Shard saved checkpoint needs to be compatible with the naming format of hf checkpoint files (#3479)
* [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format
* [checkpoint] Shard saved checkpoint add 'variant' field to customize filename

---------

Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
Co-authored-by: luchen <luchen@luchendeMBP.lan>
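The change targets the Hugging Face convention in which an optional variant is spliced into the weight filename, so pytorch_model.bin becomes pytorch_model.<variant>.bin. A minimal sketch of that naming rule follows; the helper below is illustrative and not code from this commit:

# Illustrative only: how a 'variant' typically rewrites HF-style weight names.
WEIGHTS_NAME = "pytorch_model.bin"

def add_variant(weights_name: str, variant: str = None) -> str:
    # Splice the variant in before the extension:
    # pytorch_model.bin -> pytorch_model.fp16.bin
    if variant is not None:
        name, ext = weights_name.rsplit(".", 1)
        weights_name = f"{name}.{variant}.{ext}"
    return weights_name

print(add_variant(WEIGHTS_NAME))          # pytorch_model.bin
print(add_variant(WEIGHTS_NAME, "fp16"))  # pytorch_model.fp16.bin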
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Union
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -104,7 +105,7 @@ class CheckpointIO(ABC):
                    checkpoint: str,
                    shard: bool = False,
                    gather_dtensor: bool = True,
-                   prefix: str = None,
+                   variant: str = None,
                    size_per_shard: int = 1024,
                    use_safetensors: bool = False):
         """
@@ -129,7 +130,7 @@ class CheckpointIO(ABC):
                 multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
             gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-            prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
+            variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
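A usage sketch based on the signature and docstring above. It assumes GeneralCheckpointIO, a concrete CheckpointIO implementation exported by colossalai.checkpoint_io, handles the sharded path; that class name and the directory layout are assumptions, not shown in this diff:

# Usage sketch only. GeneralCheckpointIO is assumed to be a concrete
# CheckpointIO implementation exposed by colossalai.checkpoint_io.
import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(16, 16)
ckpt_io = GeneralCheckpointIO()

ckpt_io.save_model(
    model,
    checkpoint="./ckpt_dir",   # must be a directory when shard=True
    shard=True,                # split the state dict into multiple files
    gather_dtensor=True,
    variant="fp16",            # shards are named pytorch_model.fp16*.bin
    size_per_shard=1024,       # maximum shard size in MB
    use_safetensors=False,
)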
@@ -138,7 +139,7 @@ class CheckpointIO(ABC):
             model = model.unwrap()
 
         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
@@ -219,7 +220,7 @@ class CheckpointIO(ABC):
         pass
 
     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.
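For context, a rough sketch of what an implementation of save_sharded_model has to produce: HF-style shard files plus a model.index.json that maps each weight name to its shard. The greedy packing strategy and the helper name below are assumptions for illustration, not ColossalAI's actual implementation:

# Illustrative sketch of an HF-compatible sharded save; the packing strategy
# and helper name are assumptions, not ColossalAI's implementation.
import json
import os

import torch
import torch.nn as nn


def save_sharded_state_dict(model: nn.Module, checkpoint: str,
                            variant: str = None,
                            size_per_shard: int = 1024) -> None:
    os.makedirs(checkpoint, exist_ok=True)
    state_dict = model.state_dict()

    # Greedily pack tensors into shards of at most size_per_shard MB.
    limit = size_per_shard * 1024 ** 2
    shards, current, current_size = [], {}, 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        if current and current_size + nbytes > limit:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += nbytes
    if current:
        shards.append(current)

    # Name shards pytorch_model[.<variant>]-XXXXX-of-YYYYY.bin and record
    # which file holds each weight in model.index.json.
    base = "pytorch_model" if variant is None else f"pytorch_model.{variant}"
    weight_map = {}
    for i, shard in enumerate(shards, start=1):
        shard_file = f"{base}-{i:05d}-of-{len(shards):05d}.bin"
        torch.save(shard, os.path.join(checkpoint, shard_file))
        for key in shard:
            weight_map[key] = shard_file

    with open(os.path.join(checkpoint, "model.index.json"), "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


# e.g. save_sharded_state_dict(nn.Linear(16, 16), "./ckpt_dir", variant="fp16")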