mirror of https://github.com/hpcaitech/ColossalAI.git
Refactored docstring to google style
@@ -50,17 +50,17 @@ def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
 def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
-    """This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
+    """This is a function to generate the checkpoint path from the tuple
+    (checkpoint_dir, epoch, suffix, gpu_parallel_rank).
     This is useful during generation and recuperation of the checkpoint.
 
-    :param checkpoint_dir: Set up a directory for saving checkpoints
-    :type checkpoint_dir: str
-    :param epoch: Epoch number (indicate how many epochs have you trained this model)
-    :type epoch: int
-    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
-    :type suffix: str, optional
-    :return: Checkpoint path to be generated
-    :rtype: path
+    Args:
+        checkpoint_dir (str): Set up a directory for saving checkpoints.
+        epoch (int): Epoch number (indicate how many epochs have you trained this model).
+        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
+
+    Returns:
+        str: The checkpoint path to be generated.
     """
     ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
     return os.path.join(checkpoint_dir, ckpt_filename)
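For orientation, a minimal usage sketch of get_checkpoint_path (not part of the diff). The import path colossalai.checkpointing and the suffix value are assumptions; judging from the filename pattern in get_latest_checkpoint_pattern further down, the generated name has the form epoch{N}-{ranks_name}{suffix}.pt.

# Usage sketch; module path and suffix are assumptions, not taken from the diff.
from colossalai.checkpointing import get_checkpoint_path

ckpt_path = get_checkpoint_path('./checkpoints', epoch=10, suffix='-fp16')
# e.g. './checkpoints/epoch10-<ranks_name>-fp16.pt', where <ranks_name>
# comes from _get_ranks_name() and encodes the parallel rank.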
@@ -74,12 +74,13 @@ def _ensure_directory_exists(filename: str):
 def get_latest_checkpoint_pattern(suffix: str = ''):
-    """Generate Regular expression of latest checkpoint's pattern
+    """Generate Regular expression of the latest checkpoint's pattern.
 
-    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
-    :type suffix: str, optional
-    :return: Checkpoint pattern
-    :rtype: regular expression
+    Args:
+        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''.
+
+    Returns:
+        str: The regular expression of checkpoint pattern.
     """
     ranks_name = _get_ranks_name()
     pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
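To make the pattern concrete, a small illustration of how a regex of this shape extracts the epoch number from a checkpoint filename. The ranks_name value below is hypothetical; the real one is produced by _get_ranks_name().

# Illustration only: the rank name 'tp0-pp0' is made up for this sketch.
import re

ranks_name = 'tp0-pp0'
suffix = ''
pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)

m = re.fullmatch(pattern, 'epoch42-tp0-pp0.pt')
print(int(m.group(1)) if m else None)  # -> 42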
@@ -88,16 +89,19 @@ def get_latest_checkpoint_pattern(suffix: str = ''):
 def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
-    """This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
+    """This is a function to retrieve the latest checkpoint path from the tuple
+    (checkpoint_dir, suffix, gpu_parallel_rank).
     This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
 
-    :param checkpoint_dir: Directory for saving checkpoints
-    :type checkpoint_dir: str
-    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
-    :type suffix: str, optional
-    :raises FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given
-    :return: The latest checkpoint path to be retrieved
-    :rtype: path
+    Args:
+        checkpoint_dir (str): Directory for saving checkpoints
+        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
+
+    Returns:
+        str: The latest retrieved checkpoint path.
+
+    Raises:
+        FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given.
     """
     CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
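A usage sketch for recovering the newest checkpoint when the epoch number is unknown, guarding against the FileNotFoundError documented above. The import path colossalai.checkpointing is an assumption.

# Resume-discovery sketch; import path assumed.
from colossalai.checkpointing import get_latest_checkpoint_path

try:
    latest = get_latest_checkpoint_path('./checkpoints', suffix='')
except FileNotFoundError:
    latest = None  # no checkpoint found for these inputs; start from scratch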
@@ -126,22 +130,19 @@ def save_checkpoint(checkpoint_path: str,
                     optimizer: torch.optim.Optimizer,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     **kwargs):
-    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model,
-    optimizer, lr_scheduler and etc. into a checkpoint dictionary.
+    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as
+    model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
 
-    This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.
+    This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.
 
-    :param checkpoint_path: Set up a directory for saving checkpoints
-    :type checkpoint_path: str
-    :param epoch: Epoch number (indicate how many epochs have you trained this model)
-    :type epoch: int
-    :param model: Model to be registered
-    :type model: torch.nn.Module
-    :param optimizer: Optimizer to be registered
-    :type optimizer: torch.optim.Optimizer
-    :param lr_scheduler: lr_scheduler to be registered, defaults to None
-    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
+    Args:
+        checkpoint_path (str): Set up a directory for saving checkpoints.
+        epoch (int): Epoch number (indicate how many epochs have you trained this model).
+        model (:class:`torch.nn.Module`): Model to be registered.
+        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
+            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None.
+        kwargs (dict): additional parameters to be saved.
     """
     # for compatibility with normal pytorch nn.Module
     if hasattr(model, 'state_dict_for_save_checkpoint'):
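A hedged end-to-end sketch of saving a checkpoint with the arguments documented above. The import path, the ./checkpoints directory and the extra global_step keyword are illustrative assumptions; whether save_checkpoint creates the directory itself is not shown in the diff, so the sketch creates it explicitly.

# Save sketch; module path and the extra 'global_step' kwarg are assumptions.
import os
import torch
from colossalai.checkpointing import get_checkpoint_path, save_checkpoint

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

os.makedirs('./checkpoints', exist_ok=True)
path = get_checkpoint_path('./checkpoints', epoch=3)
save_checkpoint(path, epoch=3, model=model, optimizer=optimizer,
                lr_scheduler=lr_scheduler, global_step=1200)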
@@ -165,31 +166,31 @@ def load_checkpoint(checkpoint_path: str,
                     finetune: bool = False,
                     strict: bool = True) -> Tuple:
     """Loads the checkpoint file.
 
     If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
     So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
     and its descendants.
-    If finetune is True, then only the weights and buffers of model should be reload.
-    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s
-    state_dict() function.
-    and its descendants.
-
-    :param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
-    :type checkpoint_path: str
-    :param model: Model to reload parameters and buffers
-    :type model: torch.nn.Module
-    :param optimizer: Optimizer to recuperate
-    :type optimizer: torch.optim.Optimizer
-    :param lr_scheduler: lr_scheduler to recuperate, defaults to None
-    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
-    :param finetune: Whether to finetune the model with new dataset or continue the pre-training, defaults to False
-    :type finetune: bool, optional
-    :param strict: Whether to strictly enforce that the keys in
-        :attr:`state_dict` of the checkpoint match the names of
-        parameters and buffers in model., defaults to True
-    :type strict: bool, optional
-    :raises ValueError: Raise error if the model/optimizer cannot successfully be recuperated
-    :return: (the epoch number of the checkpoint retrieved, the checkpoint retrieved)
-    :rtype: Tuple
+    If finetune is True, then only the weights and buffers of model should be reloaded.
+    If strict is True, then the keys of state_dict must exactly match the keys returned
+    by this module’s state_dict() function.
+
+    Args:
+        checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
+        model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
+        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
+        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
+            lr_scheduler to recuperate, defaults to None.
+        finetune (bool, optional): Whether to finetune the model with new dataset or
+            continue the pre-training, defaults to False.
+        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict`
+            of the checkpoint match the names of parameters and buffers in model, defaults to True.
+
+    Returns:
+        Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved).
+
+    Raises:
+        ValueError: Raise error if the model/optimizer cannot successfully be recuperated
     """
     # Load the checkpoint.
     checkpoint = torch.load(checkpoint_path, map_location='cpu')
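Finally, a resume sketch combining get_latest_checkpoint_path and load_checkpoint as documented above, reusing the model, optimizer and lr_scheduler objects from the save sketch. The import path is an assumption.

# Resume sketch; finetune=False restores training state, strict=True requires
# the checkpoint's state_dict keys to match the model exactly.
from colossalai.checkpointing import get_latest_checkpoint_path, load_checkpoint

last_epoch, checkpoint = load_checkpoint(
    get_latest_checkpoint_path('./checkpoints'),
    model, optimizer, lr_scheduler,
    finetune=False, strict=True)
start_epoch = last_epoch + 1  # continue training from the next epoch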