[checkpointio] General Checkpointing of Sharded Optimizers (#3984)
@@ -103,7 +103,7 @@ class CheckpointIO(ABC):
                    checkpoint: str,
                    shard: bool = False,
                    gather_dtensor: bool = True,
-                   variant: str = None,
+                   prefix: str = None,
                    size_per_shard: int = 1024,
                    use_safetensors: bool = False):
         """
@@ -128,7 +128,7 @@ class CheckpointIO(ABC):
                 multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
                 that the checkpoint path is a directory path instead of a file path.
             gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
-            variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
+            prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
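
Taken together with the signature hunk above, a sharded save might look like the sketch below. It assumes the parameters belong to CheckpointIO.save_model and that GeneralCheckpointIO is a usable concrete implementation; neither name appears in these hunks, so treat both as assumptions, and the checkpoint path is a placeholder.

import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed concrete subclass

model = nn.Linear(4096, 4096)
ckpt_io = GeneralCheckpointIO()

# Sharded save: per the docstring, `checkpoint` must be a directory; the weight
# shards are indexed by a model.index.json file, and each shard is capped at
# roughly `size_per_shard` MB.
ckpt_io.save_model(model,                      # method name assumed from context
                   checkpoint="./model_ckpt",  # placeholder directory path
                   shard=True,
                   gather_dtensor=True,
                   prefix=None,                # if set, weights use the pytorch_model.<prefix>.bin format
                   size_per_shard=1024,
                   use_safetensors=False)
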
@@ -137,11 +137,11 @@ class CheckpointIO(ABC):
         model = model.unwrap()

         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
+            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:
             self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
         """
         Load optimizer from checkpoint.

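
The extended load_optimizer signature above accepts prefix and size_per_shard so that sharded optimizer checkpoints go through the same entry point. A hedged usage sketch, again assuming GeneralCheckpointIO as the concrete implementation and a placeholder checkpoint path:

import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO  # assumed concrete subclass

model = nn.Linear(4096, 4096)
optimizer = Adam(model.parameters(), lr=1e-3)

ckpt_io = GeneralCheckpointIO()
# For a sharded optimizer checkpoint, `checkpoint` is the directory holding the
# shard files and their index; `prefix` and `size_per_shard` mirror the model path.
ckpt_io.load_optimizer(optimizer, "./optim_ckpt", prefix=None, size_per_shard=1024)
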
@@ -157,7 +157,7 @@ class CheckpointIO(ABC):

         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path)
+            self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
         else:
             self.load_unsharded_optimizer(optimizer, checkpoint)

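
The branch above keys off whether an index file exists for the checkpoint. An illustrative sketch of such a check; this is not ColossalAI's actual helper, and the function name and glob pattern are assumptions:

from pathlib import Path
from typing import Optional


def find_index_file(checkpoint: str) -> Optional[Path]:
    # Hypothetical helper: a sharded checkpoint directory contains a *.index.json
    # file (e.g. model.index.json) mapping parameter names to shard files.
    path = Path(checkpoint)
    if path.is_dir():
        matches = sorted(path.glob("*.index.json"))
        if matches:
            return matches[0]
    return None


index_file_path = find_index_file("./optim_ckpt")  # placeholder path
if index_file_path is not None:
    print(f"sharded checkpoint, index at {index_file_path}")  # -> load_sharded_optimizer
else:
    print("unsharded checkpoint")                             # -> load_unsharded_optimizer
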
@@ -218,7 +218,7 @@ class CheckpointIO(ABC):
         pass

     @abstractmethod
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
                            size_per_shard: int, use_safetensors: bool):
         """
         Save model to sharded checkpoint.
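
For reference, a simplified sketch of what a concrete save_sharded_model implementation could do: greedily pack the state dict into shards of at most size_per_shard MB and record a weight map in model.index.json. This is illustrative only, not GeneralCheckpointIO's actual code; distributed-tensor gathering (gather_dtensor) is omitted and the shard naming scheme is an assumption.

import json
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn


def save_sharded_model(model: nn.Module, checkpoint: str, gather_dtensor: bool,
                       prefix: Optional[str], size_per_shard: int, use_safetensors: bool):
    ckpt_dir = Path(checkpoint)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    max_bytes = size_per_shard * 1024 * 1024

    # Greedy size-based sharding of the state dict.
    shards, current, current_size = [], {}, 0
    for name, tensor in model.state_dict().items():
        nbytes = tensor.numel() * tensor.element_size()
        if current and current_size + nbytes > max_bytes:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += nbytes
    if current:
        shards.append(current)

    # Shard naming is assumed here; the docstring only fixes the
    # pytorch_model.<prefix>.bin base format.
    base = "pytorch_model" if prefix is None else f"pytorch_model.{prefix}"
    ext = "safetensors" if use_safetensors else "bin"
    weight_map = {}
    for i, shard in enumerate(shards, start=1):
        fname = f"{base}-{i:05d}-of-{len(shards):05d}.{ext}"
        if use_safetensors:
            from safetensors.torch import save_file
            save_file(shard, str(ckpt_dir / fname))
        else:
            torch.save(shard, ckpt_dir / fname)
        weight_map.update({k: fname for k in shard})

    # model.index.json maps every parameter to the shard file that stores it.
    with open(ckpt_dir / "model.index.json", "w") as f:
        json.dump({"weight_map": weight_map}, f, indent=2)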