Fixed docstring in colossalai (#171)

HELSON
2022-01-21 10:44:30 +08:00
committed by GitHub
parent e2089c5c15
commit 0f8c7f9804
77 changed files with 983 additions and 603 deletions


@@ -19,9 +19,8 @@ __all__ = [
 def unwrap_config(config: Config):
-    '''
-    unwrap Config objects to normal dicts
-    '''
+    """Unwrap Config objects to normal dicts
+    """
     config_dict = dict()
     for k, v in config.items():
         if isinstance(v, dict):
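
For reference, a minimal usage sketch of the unwrap_config helper touched above. This is not part of the commit; the import paths and the Config(...) constructor call are assumptions.

# Hypothetical example; import locations and Config(...) usage are assumptions.
from colossalai.context import Config
from colossalai.checkpointing import unwrap_config

cfg = Config({'optimizer': {'lr': 1e-3, 'weight_decay': 1e-2}})
plain = unwrap_config(cfg)                  # nested Config/dict values come back as plain dicts
assert isinstance(plain['optimizer'], dict)
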
@@ -53,18 +52,18 @@ def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
 def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
-    '''This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
+    """This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
     This is useful during generation and recuperation of the checkpoint.
-    :param checkpoint_dir: set up a directory for saving checkpoints
+    :param checkpoint_dir: Set up a directory for saving checkpoints
     :type checkpoint_dir: str
-    :param epoch: epoch number (indicate how many epochs have you trained this model)
+    :param epoch: Epoch number (indicate how many epochs have you trained this model)
     :type epoch: int
-    :param suffix: additional notation to specify the model or checkpoint, defaults to ''
+    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
     :type suffix: str, optional
-    :return: checkpoint path to be generated
+    :return: Checkpoint path to be generated
     :rtype: path
-    '''
+    """
     ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
     return os.path.join(checkpoint_dir, ckpt_filename)
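
A hedged usage sketch for get_checkpoint_path follows. The module path is an assumption (the changed file's name is not visible in this view), and the rank component of the filename comes from _get_ranks_name(), so an initialized ColossalAI parallel context is assumed.

# Hypothetical example; import path and directory name are illustrative only.
from colossalai.checkpointing import get_checkpoint_path

ckpt_path = get_checkpoint_path('./checkpoints', epoch=10, suffix='-vit')
# yields something like './checkpoints/epoch10-<ranks_name>-vit.pt',
# matching the epoch(\d+)-{ranks_name}{suffix}\.pt pattern used below
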
@@ -77,30 +76,30 @@ def _ensure_directory_exists(filename: str):
 def get_latest_checkpoint_pattern(suffix: str = ''):
-    '''Generate Regular expression of latest checkpoint's pattern
+    """Generate Regular expression of latest checkpoint's pattern
-    :param suffix: additional notation to specify the model or checkpoint, defaults to ''
+    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
     :type suffix: str, optional
-    :return: checkpoint pattern
+    :return: Checkpoint pattern
     :rtype: regular expression
-    '''
+    """
     ranks_name = _get_ranks_name()
     ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
     return ckpt_pattern
 def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
-    '''This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
+    """This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
     This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
-    :param checkpoint_dir: directory for saving checkpoints
+    :param checkpoint_dir: Directory for saving checkpoints
     :type checkpoint_dir: str
-    :param suffix: additional notation to specify the model or checkpoint, defaults to ''
+    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
     :type suffix: str, optional
-    :raises FileNotFoundError: raise error when we cannot find the latest checkpoint file with inputs given
-    :return: the latest checkpoint path to be retrieved
+    :raises FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given
+    :return: The latest checkpoint path to be retrieved
     :rtype: path
-    '''
+    """
     CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
     last_epoch = -1
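
A sketch of looking up the most recent checkpoint with the two helpers above; the directory and suffix are illustrative, and the import path is again an assumption.

# Hypothetical example; falls back to training from scratch when nothing is found.
from colossalai.checkpointing import get_latest_checkpoint_path

try:
    latest = get_latest_checkpoint_path('./checkpoints', suffix='-vit')
except FileNotFoundError:
    latest = None   # no checkpoint saved yet
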
@@ -128,22 +127,22 @@ def save_checkpoint(checkpoint_path: str,
                     optimizer: torch.optim.Optimizer,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     **kwargs):
-    '''Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
+    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
     This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.
-    :param checkpoint_path: set up a directory for saving checkpoints
+    :param checkpoint_path: Set up a directory for saving checkpoints
     :type checkpoint_path: str
-    :param epoch: epoch number (indicate how many epochs have you trained this model)
+    :param epoch: Epoch number (indicate how many epochs have you trained this model)
     :type epoch: int
-    :param model: model to be registered
+    :param model: Model to be registered
     :type model: torch.nn.Module
-    :param optimizer: optimizer to be registered
+    :param optimizer: Optimizer to be registered
     :type optimizer: torch.optim.Optimizer
     :param lr_scheduler: lr_scheduler to be registered, defaults to None
     :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
-    '''
+    """
     # for compatibility with normal pytorch nn.Module
     if hasattr(model, 'state_dict_for_save_checkpoint'):
         model_sd = model.state_dict_for_save_checkpoint()
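
A minimal end-of-epoch save sketch based on the signature documented above. The toy model, optimizer, and scheduler are placeholders, the import path is an assumption, and a launched ColossalAI context is also assumed because the checkpoint filename encodes parallel ranks.

# Hypothetical example with placeholder training objects.
import torch
from colossalai.checkpointing import get_checkpoint_path, save_checkpoint

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

epoch = 5
ckpt_path = get_checkpoint_path('./checkpoints', epoch=epoch)
save_checkpoint(ckpt_path, epoch, model, optimizer, lr_scheduler=lr_scheduler)
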
@@ -170,31 +169,31 @@ def load_checkpoint(checkpoint_path: str,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     finetune: bool = False,
                     strict: bool = True) -> Tuple:
-    '''Loads the checkpoint file.
-    If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
-    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
+    """Loads the checkpoint file.
+    If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
+    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
     If finetune is True, then only the weights and buffers of model should be reload.
     If strict is True, then the keys of state_dict must exactly match the keys returned by this modules state_dict() function.
-    :param checkpoint_path: the exact and matched checkpoint_path directory to retrieve appropriate state_dict
+    :param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
     :type checkpoint_path: str
-    :param model: model to reload parameters and buffers
+    :param model: Model to reload parameters and buffers
     :type model: torch.nn.Module
-    :param optimizer: optimizer to recuperate
-    :type optimizer: torch.optim.Optimizer
+    :param optimizer: Optimizer to recuperate
+    :type optimizer: torch.optim.Optimizer
     :param lr_scheduler: lr_scheduler to recuperate, defaults to None
     :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
-    :param finetune: whether to finetune the model with new dataset or continue the pre-training, defaults to False
+    :param finetune: Whether to finetune the model with new dataset or continue the pre-training, defaults to False
     :type finetune: bool, optional
-    :param strict: whether to strictly enforce that the keys in
+    :param strict: Whether to strictly enforce that the keys in
         :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model., defaults to True
     :type strict: bool, optional
-    :raises ValueError: raise error if the model/optimizer cannot successfully be recuperated
+    :raises ValueError: Raise error if the model/optimizer cannot successfully be recuperated
     :return: (the epoch number of the checkpoint retrieved, the checkpoint retrieved)
     :rtype: Tuple
-    '''
+    """
     # Load the checkpoint.
     checkpoint = torch.load(checkpoint_path, map_location='cpu')
     try:
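
Finally, a hedged resume sketch pairing load_checkpoint with the latest-checkpoint lookup, reusing the placeholder objects from the save sketch above; load_checkpoint returns (epoch, checkpoint) per its docstring.

# Hypothetical example; strict=True requires the state_dict keys to match the model exactly.
from colossalai.checkpointing import get_latest_checkpoint_path, load_checkpoint

last_path = get_latest_checkpoint_path('./checkpoints')
last_epoch, ckpt = load_checkpoint(last_path, model, optimizer, lr_scheduler=lr_scheduler)
start_epoch = last_epoch + 1   # continue training from the next epoch
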