mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[doc] add tutorial for booster checkpoint (#3785)
* [doc] add checkpoint related docstr for booster
* [doc] add en checkpoint doc
* [doc] add zh checkpoint doc
* [doc] add booster checkpoint doc in sidebar
* [doc] add caution about ckpt for plugins
* [doc] add doctest placeholder
* [doc] add doctest placeholder
* [doc] add doctest placeholder
@@ -151,6 +151,16 @@ class Booster:
        return self.plugin.no_sync(model)

    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        """Load model from checkpoint.

        Args:
            model (nn.Module): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
            strict (bool, optional): Whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
        """
        self.checkpoint_io.load_model(model, checkpoint, strict)
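
    # Illustrative usage only (the path below is an example, not part of the
    # API); ``model`` here is the object returned by ``booster.boost(...)``:
    #     >>> booster.load_model(model, "./ckpt/model.pt")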

    def save_model(self,
@@ -159,16 +169,58 @@ class Booster:
                   prefix: str = None,
                   shard: bool = False,
                   size_per_shard: int = 1024):
        """Save model to checkpoint.

        Args:
            model (nn.Module): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It is a file path if ``shard=False``. Otherwise, it is a directory path.
            prefix (str, optional): A prefix added to parameter and buffer
                names to compose the keys in state_dict. Defaults to None.
            shard (bool, optional): Whether to save the checkpoint in a sharded way.
                If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
            size_per_shard (int, optional): Maximum size of each checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
        """
        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
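
    # Illustrative usage only: with ``shard=True`` the example path below
    # becomes a folder of shard files, each at most ~512 MB:
    #     >>> booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=512)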

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): An optimizer boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
        """
        self.checkpoint_io.load_optimizer(optimizer, checkpoint)
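
    # Illustrative usage only; ``optimizer`` must be the optimizer returned
    # by ``booster.boost(...)``, and the example path is hypothetical:
    #     >>> booster.load_optimizer(optimizer, "./ckpt/optimizer.pt")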

    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        """Save optimizer to checkpoint.

        Warning: Saving sharded optimizer checkpoints is not supported yet.

        Args:
            optimizer (Optimizer): An optimizer boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local path.
                It is a file path if ``shard=False``. Otherwise, it is a directory path.
            shard (bool, optional): Whether to save the checkpoint in a sharded way.
                If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
            size_per_shard (int, optional): Maximum size of each checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
        """
        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
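
    # Illustrative usage only: because sharded saving is not supported yet,
    # keep the default ``shard=False`` and pass a file path:
    #     >>> booster.save_optimizer(optimizer, "./ckpt/optimizer.pt")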

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): An lr scheduler boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
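
    # Illustrative usage only (per the docstring above, an lr scheduler
    # checkpoint is always a single file):
    #     >>> booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")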

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): An lr scheduler boosted by Booster.
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
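
Putting the six methods together, the following is a minimal end-to-end sketch of a save/restore cycle. It is illustrative rather than taken from this commit: it assumes the script is launched with torchrun, picks the TorchDDPPlugin, and invents the model, optimizer, and checkpoint paths.

import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes the usual torch distributed environment variables are set,
# e.g. by launching this script with ``torchrun``.
colossalai.launch_from_torch(config={})

booster = Booster(plugin=TorchDDPPlugin())
model = nn.Linear(16, 2)
optimizer = SGD(model.parameters(), lr=1e-3)
lr_scheduler = StepLR(optimizer, step_size=10)
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

# ... training loop elided ...

# Save: the model is sharded into a folder; the optimizer and lr scheduler
# each go to a single file (sharded optimizer saving is not supported yet).
booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=512)
booster.save_optimizer(optimizer, "./ckpt/optimizer.pt")
booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")

# Restore: after re-creating and re-boosting the same objects, pass the
# directory for the sharded model and the file paths for the rest.
booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer.pt")
booster.load_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")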