Fixed docstring in colossalai (#171)

This commit is contained in:
HELSON
2022-01-21 10:44:30 +08:00
committed by GitHub
parent e2089c5c15
commit 0f8c7f9804
77 changed files with 983 additions and 603 deletions


@@ -6,6 +6,8 @@ import socket
 import torch
 from torch._six import inf
+import colossalai.context.parallel_mode
 try:
     import colossal_C
 except:
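For context, the try/except above implements an optional-extension fallback: the fused CUDA kernels in colossal_C are used when the extension was built, and a pure-PyTorch path is used otherwise. A minimal sketch of that pattern (the flag name and fallback message are illustrative, not taken from the truncated hunk):

try:
    import colossal_C
    HAS_COLOSSAL_C = True
except ImportError:
    # The extension was not compiled; downstream code should fall back
    # to the plain PyTorch implementation.
    HAS_COLOSSAL_C = False
    print('colossal_C is not installed, falling back to the PyTorch implementation')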
@@ -23,11 +25,13 @@ from .multi_tensor_apply import multi_tensor_applier
 def print_rank_0(msg: str, logger=None):
-    '''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
+    """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
-    :param msg: A str message to output
-    :param logger: python logger object, defaults to None
-    '''
+    :param msg: A string message to output
+    :type msg: str
+    :param logger: Python logger object, defaults to None
+    :type logger: optional
+    """
     if gpc.get_global_rank() == 0:
         if logger is None:
             print(msg, flush=True)
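The updated docstring documents the rank-0 guard: only the process with global rank 0 prints or logs, so a message appears once per job rather than once per GPU. A minimal sketch of the same pattern in plain torch.distributed (gpc and the real import path of print_rank_0 are not shown in this hunk, so this is only illustrative):

import torch.distributed as dist

def print_on_rank_0(msg: str, logger=None):
    # Emit the message only from global rank 0 (or when not running distributed).
    if not dist.is_initialized() or dist.get_rank() == 0:
        if logger is None:
            print(msg, flush=True)
        else:
            logger.info(msg)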
@@ -48,10 +52,13 @@ def free_port():
 def sync_model_param(model, parallel_mode):
-    '''Make sure data parameters are consistent during Data Parallel Mode
+    """Make sure data parameters are consistent during Data Parallel Mode
     :param model: A pyTorch nn.model on whose parameters you check the consistency
-    '''
+    :param parallel_mode: Parallel mode to be checked
+    :type model: torch.nn.Module
+    :type parallel_mode: colossalai.context.ParallelMode
+    """
     if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
         for param in model.parameters():
             ranks = gpc.get_ranks_in_group(parallel_mode)
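sync_model_param enforces consistency by broadcasting every parameter within the given parallel group, so all ranks start from identical weights. A rough equivalent with plain torch.distributed, assuming the default process group instead of gpc's group handles:

import torch
import torch.distributed as dist

def sync_model_param_sketch(model: torch.nn.Module, group=None):
    # Skip when the job is not distributed or the group has a single rank.
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        for param in model.parameters():
            # Broadcast from the first rank so every replica holds the same values.
            dist.broadcast(param.data, src=0, group=group)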
@@ -124,18 +131,17 @@ def _calc_lp(grads, norm_type):
 def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
-    """Clips gradient norm of an iterable of parameters whose gradients
-    are in fp32.
+    """Clips gradient norm of an iterable of parameters whose gradients are in fp32.
-    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
+    This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
     added functionality to handle model parallel parameters. Note that
     the gradients are modified in place.
-    :param parameters: an iterable of Tensors or a single Tensor that will have gradients normalized
+    :param parameters: An iterable of Tensors or a single Tensor that will have gradients normalized
     :type parameters: (Iterable[Tensor] or Tensor)
-    :param max_norm: max norm of the gradients
+    :param max_norm: Max norm of the gradients
     :type max_norm: float or int
-    :param norm_type: type of the used p-norm. Can be ``'inf'`` for infinity norm.
+    :param norm_type: Type of the used p-norm. Can be ``'inf'`` for infinity norm.
     :type norm_type: float or int
     :return: Total norm of the parameters (viewed as a single vector).
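For reference, the upstream utility this function adapts is torch.nn.utils.clip_grad_norm_; a minimal single-process usage sketch (the model and tensor shapes are arbitrary):

import torch
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(8, 2)
loss = model(torch.randn(4, 8)).sum()
loss.backward()
# Rescale gradients in place so their combined 2-norm does not exceed 1.0;
# the returned value is the total norm of the gradients before clipping.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2)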