[checkpoint] Sharded checkpoints need to be compatible with the naming format of HF checkpoint files (#3479)

* [checkpoint] support HuggingFace-style sharded checkpoints, compatible with the HF file naming format

* [checkpoint] add a 'variant' field to sharded checkpoint saving to customize the filename

---------

Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
Co-authored-by: luchen <luchen@luchendeMBP.lan>
Author: jiangmingyan
Date:   2023-04-12 16:02:17 +08:00 (committed by GitHub)
Parent: 7182ac2a04
Commit: 366a035552

3 changed files with 29 additions and 12 deletions
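For context, a minimal usage sketch of the API this change documents. `checkpoint_io` stands for any concrete CheckpointIO implementation; its name, the checkpoint path, and the variant value are assumptions for illustration, while the save_model signature comes from the diff below:

    checkpoint_io.save_model(model,
                             checkpoint="./ckpt_dir",   # a directory path when shard=True
                             shard=True,
                             gather_dtensor=True,
                             variant="fp16",            # customizes the weight filenames
                             size_per_shard=1024,
                             use_safetensors=False)
    # With a variant set, weights follow the pytorch_model.<variant>.bin naming described
    # in the docstring below, and the shards are indexed by a model.index.json file.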

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Union
+from typing import Optional
 import torch
 import torch.nn as nn
@@ -104,7 +105,7 @@ class CheckpointIO(ABC):
 checkpoint: str,
 shard: bool = False,
 gather_dtensor: bool = True,
-prefix: str = None,
+variant: str = None,
 size_per_shard: int = 1024,
 use_safetensors: bool = False):
 """
@@ -129,7 +130,7 @@
 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
 that the checkpoint path is a directory path instead of a file path.
 gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-prefix (str): prefix for the model checkpoint file name when shard=True. Default: None.
+variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
 size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
 use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
 """
@@ -138,7 +139,7 @@
 model = model.unwrap()
 if shard:
-self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
+self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
 else:
 self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
@@ -219,7 +220,7 @@
 pass
 @abstractmethod
-def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str,
+def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
 size_per_shard: int, use_safetensors: bool):
 """
 Save model to sharded checkpoint.
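For illustration, a hedged sketch of how a concrete subclass could fold `variant` into HuggingFace-style shard filenames; this helper is hypothetical and is not the implementation added by this commit:

    from typing import Optional

    def _shard_filename(variant: Optional[str], idx: int, total: int, use_safetensors: bool) -> str:
        # Hypothetical helper: picks a base name and extension, then inserts the optional
        # variant, e.g. "pytorch_model.fp16-00001-of-00002.bin" or, with no variant,
        # "pytorch_model-00001-of-00002.bin".
        base, ext = ("model", "safetensors") if use_safetensors else ("pytorch_model", "bin")
        variant_part = f".{variant}" if variant else ""
        return f"{base}{variant_part}-{idx:05d}-of-{total:05d}.{ext}"

A subclass's save_sharded_model would then write each shard under such a name and record the mapping from parameter names to shard files in the accompanying index file.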