[checkpointio] General Checkpointing of Sharded Optimizers (#3984)

Author: Baizhou Zhang
Date: 2023-06-15 15:21:26 +08:00
Committed by: GitHub
Parent: 8bcad73677
Commit: c9cff7e7fa
8 changed files with 399 additions and 38 deletions

@@ -103,7 +103,7 @@ class CheckpointIO(ABC):
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
-variant: str = None,
+prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""
@@ -128,7 +128,7 @@ class CheckpointIO(ABC):
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
+prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
@@ -137,11 +137,11 @@ class CheckpointIO(ABC):
model = model.unwrap()
if shard:
-self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
+self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
-def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
Load optimizer from checkpoint.
@@ -157,7 +157,7 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
-self.load_sharded_optimizer(optimizer, index_file_path)
+self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
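
Note: the new prefix and size_per_shard arguments only take effect when an index file is found, i.e. when the optimizer checkpoint is sharded; otherwise the unsharded path is taken as before. A minimal sketch of the call, where the model/optimizer setup and the ./optim_ckpt path are illustrative assumptions:

# Illustrative sketch only; the setup and checkpoint path are assumptions,
# the load_optimizer signature is the one added in this diff.
import torch
import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed concrete subclass

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
ckpt_io = GeneralCheckpointIO()

# With an index file present, this forwards prefix and size_per_shard to
# load_sharded_optimizer; without one, load_unsharded_optimizer is used.
ckpt_io.load_optimizer(optimizer, "./optim_ckpt", prefix=None, size_per_shard=1024)
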
@@ -218,7 +218,7 @@ class CheckpointIO(ABC):
pass
@abstractmethod
-def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to sharded checkpoint.
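
For implementers of the abstract method above, the sketch below shows one generic way a size_per_shard budget can drive sharding of a state dict and the companion index file. It is not the logic added in this commit; shard_by_size, the file naming, and the index layout are illustrative assumptions.

# Generic illustration of size-based sharding; not ColossalAI's implementation.
import json
import os

import torch
import torch.nn as nn


def shard_by_size(state_dict, max_shard_size_mb):
    """Greedily pack tensors into shards no larger than max_shard_size_mb (illustrative helper)."""
    limit = max_shard_size_mb * 1024 ** 2
    shards, current, current_size = [], {}, 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if current and current_size + tensor_size > limit:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += tensor_size
    if current:
        shards.append(current)
    return shards


def save_sharded_sketch(model, checkpoint_dir, prefix=None, size_per_shard=1024):
    """Write shard files plus an index mapping each weight name to its shard file."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    base = f"pytorch_model.{prefix}" if prefix else "pytorch_model"
    shards = shard_by_size(model.state_dict(), size_per_shard)
    weight_map = {}
    for i, shard in enumerate(shards, start=1):
        filename = f"{base}-{i:05d}-of-{len(shards):05d}.bin"
        torch.save(shard, os.path.join(checkpoint_dir, filename))
        weight_map.update({name: filename for name in shard})
    with open(os.path.join(checkpoint_dir, "model.index.json"), "w") as f:
        json.dump({"weight_map": weight_map}, f)


save_sharded_sketch(nn.Linear(8, 8), "./ckpt_sketch", prefix="demo")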