mirror of https://github.com/hpcaitech/ColossalAI.git
Refactored docstring to Google style
@@ -37,8 +37,8 @@ def get_default_parser():
     """Reads user command line and uses an argument parser to parse the input arguments.
     Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

-    :return: Returns the parser with the default arguments, the user may add customized arguments into this parser
-    :rtype: Namespace
+    Returns:
+        Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
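For orientation, a minimal usage sketch of the parser documented above; the extra --batch-size flag is purely illustrative and not part of the library:

import colossalai

# Get the default parser (config/host/port/world size/local rank/backend flags)
parser = colossalai.get_default_parser()
# Add a hypothetical custom argument on top of the defaults
parser.add_argument('--batch-size', type=int, default=32, help='illustrative custom flag')
args = parser.parse_args()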
@@ -63,26 +63,21 @@ def launch(config: Union[str, Path, Config, Dict],
     """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
     arguments are not given. Then initialize and set distributed environment by calling global_context's functions.

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param rank: Rank for the default process group
-    :type rank: int
-    :param world_size: World size of the default process group
-    :type world_size: int
-    :param host: The master address for distributed training
-    :type host: str
-    :param port: The master port for distributed training
-    :type port: str
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param local_rank: Rank for the process on the node and is used to set the default CUDA device, defaults to None.
-        If local_rank = None, the default device ordinal will be calculated automatically
-    :type local_rank: int, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
-    :raises Exception: Raise exception when config type is wrong
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable
+        rank (int): Rank for the default process group
+        world_size (int): World size of the default process group
+        host (str): The master address for distributed training
+        port (str): The master port for distributed training
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+        local_rank (int, optional):
+            Rank for the process on the node and is used to set the default CUDA device,
+            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
+
+    Raises:
+        Exception: Raise exception when config type is wrong
     """
     gpc.verbose = verbose

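A minimal sketch of calling the function documented above for a single-process group, assuming a config file exists at ./config.py; the address, port and path are illustrative values, not part of the commit:

import colossalai

colossalai.launch(config='./config.py',    # path to an assumed config file
                  rank=0,                   # rank in the default process group
                  world_size=1,             # single-process run
                  host='127.0.0.1',         # master address
                  port='29500',             # master port (documented as str)
                  backend='nccl',
                  seed=1024,
                  verbose=True)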
@@ -126,18 +121,13 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
     """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
     set by SLURM

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param host: The master address for distributed training
-    :type host: str
-    :param port: The master port for distributed training
-    :type port: str
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable
+        host (str): The master address for distributed training
+        port (str): The master port for distributed training
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     rank = int(os.environ['SLURM_PROCID'])
     world_size = int(os.environ['SLURM_NPROCS'])
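A sketch of the SLURM wrapper documented above, assuming the script is started through srun so that SLURM_PROCID and SLURM_NPROCS are set; the node name and port are placeholders:

import colossalai

# launched e.g. via: srun -N 2 --ntasks-per-node=1 python train.py
colossalai.launch_from_slurm(config='./config.py',
                             host='node001',    # hypothetical master node
                             port='29500')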
@@ -160,18 +150,13 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
     """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
     set by OpenMPI

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param host: The master address for distributed training
-    :type host: str
-    :param port: The master port for distributed training
-    :type port: str
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable
+        host (str): The master address for distributed training
+        port (str): The master port for distributed training
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
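Similarly, a sketch for the OpenMPI wrapper documented above, assuming an mpirun launch so that OMPI_COMM_WORLD_RANK and related variables are present; the address and port are placeholders:

import colossalai

# launched e.g. via: mpirun -np 4 python train.py
colossalai.launch_from_openmpi(config='./config.py',
                               host='192.168.0.1',   # hypothetical master address
                               port='29500')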
@@ -194,14 +179,11 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
     """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     rank = int(os.environ['RANK'])
     local_rank = int(os.environ['LOCAL_RANK'])
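A sketch for the PyTorch wrapper documented above, assuming the script is started with torchrun (or torch.distributed.launch) so that RANK, LOCAL_RANK and WORLD_SIZE are exported by the launcher:

import colossalai

# launched e.g. via: torchrun --nproc_per_node=2 train.py
colossalai.launch_from_torch(config='./config.py')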
@@ -230,22 +212,20 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

-    :param model: Your model instance or a function to build the model
-    :type model: :class:`torch.nn.Module` or Callbale
-    :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
-    :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
-    :param train_dataloader: Dataloader for training
-    :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param test_dataloader: Dataloader for testing
-    :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance, optional
-    :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
-    :return: (engine, train_dataloader, test_dataloader, lr_scheduler)
-    :rtype: Tuple
+    Args:
+        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
+        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
+            Your optimizer instance.
+        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
+        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
+        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
+        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance.
+        verbose (bool, optional): Whether to print logs.
+
+    Returns:
+        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
+            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
+            where only ``engine`` could not be None.
     """
     # get logger
     logger = get_dist_logger()
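Finally, a hedged sketch of wiring training components through the function documented above; the tiny model, optimizer and loss are illustrative placeholders, and the dataloaders are omitted since only the returned engine is required to be non-None:

import torch
import torch.nn as nn
import colossalai

model = nn.Linear(16, 2)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)   # placeholder optimizer
criterion = nn.CrossEntropyLoss()                          # placeholder criterion

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion)

# A training step then typically goes through the engine wrapper, e.g.
# engine.zero_grad(); output = engine(data); loss = engine.criterion(output, label)
# engine.backward(loss); engine.step()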