Refactored docstrings to Google style

Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions


@@ -4,20 +4,30 @@ from torch.optim import Optimizer
 from colossalai.utils import is_no_pp_or_last_stage
 from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
 from .grad_scaler import DynamicGradScaler, ConstantGradScaler
+from ._fp16_optimizer import FP16Optimizer
 def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
-    """A helper function to wrap training components with naive AMP modules
+    """A helper function to wrap training components with naive AMP modules. In this mode,
+    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32
+    to calculate the loss, which is equivalent to Apex O3.
 
-    :param model: your model object
-    :type model: :class:`torch.nn.Module`
-    :param optimizer: your optimizer object
-    :type optimizer: :class:`torch.optim.Optimizer`
-    :param amp_config: configuration for naive mode amp
-    :type amp_config: :class:`colossalai.context.Config` or dict
+    Args:
+        model (:class:`torch.nn.Module`): your model object
+        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
+        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode AMP
 
-    :return: (model, optimizer)
-    :rtype: Tuple
+    The ``amp_config`` should contain the parameters below:
+        verbose (bool, optional): if set to ``True``, will print debug info (default: ``False``)
+        clip_grad_norm (float, optional): clip gradients with this global L2 norm (default: 0).
+            Note that clipping is ignored if ``clip_grad_norm == 0``
+        dynamic_grad_scale (bool): whether to use the dynamic grad scaler
+
+    Returns:
+        Tuple: a tuple ``(model, optimizer)``
     """
     if isinstance(model, nn.ModuleList):
         # interleaved pipeline
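
As an aside on what the new docstring describes: "naive" AMP runs the entire forward pass in FP16 and only promotes the outputs to FP32 for the loss, matching Apex O3. A minimal hand-rolled sketch of that casting pattern in plain PyTorch (an illustration of the described behavior, not ColossalAI's actual wrapper) could look like this:

import torch
import torch.nn as nn

def naive_amp_forward(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    # Forcibly cast the model weights and the inputs to FP16.
    model = model.half()
    outputs = model(inputs.half())
    # Cast the outputs back to FP32 so the loss is computed in full precision.
    return outputs.float()
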
@@ -46,4 +56,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
     return model, optimizer
 
-__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
+__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
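
For context, a hypothetical usage sketch of the refactored helper follows. The import path is inferred from the module's relative imports, and the toy model and optimizer are placeholders; only the ``amp_config`` keys come from the docstring above.

import torch
import torch.nn as nn
from colossalai.amp.naive_amp import convert_to_naive_amp  # assumed path

# Placeholder model and optimizer, for illustration only (assumes a CUDA device).
model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

amp_config = dict(
    verbose=False,            # print debug info
    clip_grad_norm=1.0,       # global L2 norm for clipping; 0 disables clipping
    dynamic_grad_scale=True,  # use the dynamic grad scaler
)

# Returns the FP16-wrapped model and the AMP optimizer.
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)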