Refactored docstrings to Google style

Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions


@@ -37,8 +37,8 @@ def get_default_parser():
"""Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
:return: Returns the parser with the default arguments, the user may add customized arguments into this parser
:rtype: Namespace
Returns:
Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
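A minimal usage sketch of the parser documented above; the top-level import path (colossalai.get_default_parser) and the extra --batch_size flag are illustrative assumptions, not part of this commit:

    import colossalai

    # Build the default parser (it provides --config plus host, port, world size,
    # local rank and backend arguments) and extend it with a hypothetical flag.
    parser = colossalai.get_default_parser()
    parser.add_argument('--batch_size', type=int, default=32, help='hypothetical extra argument')
    args = parser.parse_args()
    print(args.config, args.batch_size)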
@@ -63,26 +63,21 @@ def launch(config: Union[str, Path, Config, Dict],
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param rank: Rank for the default process group
:type rank: int
:param world_size: World size of the default process group
:type world_size: int
:param host: The master address for distributed training
:type host: str
:param port: The master port for distributed training
:type port: str
:param backend: Backend for torch.distributed
:type backend: str, optional
:param local_rank: Rank for the process on the node and is used to set the default CUDA device, defaults to None.
If local_rank = None, the default device ordinal will be calculated automatically
:type local_rank: int, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
:raises Exception: Raise exception when config type is wrong
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
rank (int): Rank for the default process group
world_size (int): World size of the default process group
host (str): The master address for distributed training
port (str): The master port for distributed training
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
local_rank (int, optional):
Rank for the process on the node and is used to set the default CUDA device,
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
Raises:
Exception: Raise exception when config type is wrong
"""
gpc.verbose = verbose
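A minimal sketch of calling launch directly with the arguments documented above; the config path, host and port values are placeholders, and a single-process group is assumed:

    import colossalai

    # Manually supply rank and world size instead of relying on a launcher wrapper.
    colossalai.launch(config='./config.py',      # placeholder config path
                      rank=0,
                      world_size=1,
                      host='127.0.0.1',          # placeholder master address
                      port='29500',              # documented as str above
                      backend='nccl',
                      seed=1024,
                      verbose=True)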
@@ -126,18 +121,13 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param host: The master address for distributed training
:type host: str
:param port: The master port for distributed training
:type port: str
:param backend: Backend for torch.distributed
:type backend: str, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
host (str): The master address for distributed training
port (str): The master port for distributed training
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NPROCS'])
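A hedged sketch of driving this SLURM wrapper from a training script; the script name, hostname, port, config path and the srun command in the comment are illustrative only:

    # train_slurm.py -- SLURM_PROCID and SLURM_NPROCS are read internally, as shown above.
    import colossalai

    colossalai.launch_from_slurm(config='./config.py',   # placeholder config path
                                 host='node001',         # placeholder master address
                                 port='29500',           # placeholder master port
                                 backend='nccl')
    # launched e.g. with: srun python train_slurm.py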
@@ -160,18 +150,13 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param host: The master address for distributed training
:type host: str
:param port: The master port for distributed training
:type port: str
:param backend: Backend for torch.distributed
:type backend: str, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
host (str): The master address for distributed training
port (str): The master port for distributed training
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
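An analogous sketch for the OpenMPI wrapper; again, the script name, hostname, port, config path and the mpirun command are illustrative assumptions:

    # train_mpi.py -- OMPI_COMM_WORLD_RANK and OMPI_COMM_WORLD_LOCAL_RANK are read internally, as shown above.
    import colossalai

    colossalai.launch_from_openmpi(config='./config.py',   # placeholder config path
                                   host='node001',         # placeholder master address
                                   port='29500')           # placeholder master port
    # launched e.g. with: mpirun -np 4 python train_mpi.py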
@@ -194,14 +179,11 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param backend: Backend for torch.distributed
:type backend: str, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
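A sketch of the torchrun path, which only needs the config because RANK and LOCAL_RANK (and related variables) are set by the PyTorch launcher; the script name, config path and torchrun command are placeholders:

    # train_torch.py
    import colossalai

    colossalai.launch_from_torch(config='./config.py')   # placeholder config path
    # launched e.g. with: torchrun --nproc_per_node 4 train_torch.py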
@@ -230,22 +212,20 @@ def initialize(model: nn.Module,
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
:param model: Your model instance or a function to build the model
:type model: :class:`torch.nn.Module` or Callbale
:param optimizer: Your optimizer instance
:type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
:param criterion: Your criterion instance
:type criterion: :class:`torch.nn.modules.loss._Loss`, optional
:param train_dataloader: Dataloader for training
:type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
:param test_dataloader: Dataloader for testing
:type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
:param lr_scheduler: Your lr scheduler instance, optional
:type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
:return: (engine, train_dataloader, test_dataloader, lr_scheduler)
:rtype: Tuple
Args:
model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model.
optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
Your optimizer instance.
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
verbose (bool, optional): Whether to print logs.
Returns:
Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
where only ``engine`` could not be None.
"""
# get logger
logger = get_dist_logger()
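A minimal sketch of wrapping user components with initialize after the distributed environment has been set up; the toy model, optimizer and criterion are placeholders, and the dataloaders are omitted since they are documented as optional:

    import torch
    import torch.nn as nn
    import colossalai

    # The distributed environment must be launched first so gpc.config is populated.
    colossalai.launch_from_torch(config='./config.py')   # placeholder config path

    # Placeholder components standing in for a real training setup.
    model = nn.Linear(16, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
                                                                                    optimizer=optimizer,
                                                                                    criterion=criterion)
    # Per the docstring above, only `engine` is guaranteed to be non-None; the
    # dataloaders and scheduler may be None since none were passed in here.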