mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 11:44:03 +00:00
Refactored docstring to google style
This commit is contained in:
@@ -4,20 +4,30 @@ from torch.optim import Optimizer
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
|
||||
from .grad_scaler import DynamicGradScaler, ConstantGradScaler
|
||||
from ._fp16_optimizer import FP16Optimizer
|
||||
|
||||
|
||||
def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
|
||||
"""A helper function to wrap training components with naive AMP modules
|
||||
"""A helper function to wrap training components with naive AMP modules. In this mode,
|
||||
we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
|
||||
which is equivalent to Apex O3.
|
||||
|
||||
:param model: your model object
|
||||
:type model: :class:`torch.nn.Module`
|
||||
:param optimizer: your optimizer object
|
||||
:type optimizer: :class:`torch.optim.Optimizer`
|
||||
:param amp_config: configuration for naive mode amp
|
||||
:type amp_config: :class:`colossalai.context.Config` or dict
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): your model object
|
||||
optimizer (:class:`torch.optim.Optimizer`): your optimizer object
|
||||
amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
|
||||
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
|
||||
The ``amp_config`` should contain parameters below:
|
||||
:
|
||||
|
||||
verbose (bool, optional): if set to `True`, will print debug info (Default: False).
|
||||
clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
|
||||
Note that clipping is ignored if clip_grad == 0.
|
||||
dynamic_grad_scale (bool): whether to use dynamic grad scaler.
|
||||
|
||||
Returns:
|
||||
Tuples: A tuple (model, optimizer)
|
||||
"""
|
||||
if isinstance(model, nn.ModuleList):
|
||||
# interleaved pipeline
|
||||
@@ -46,4 +56,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
|
||||
return model, optimizer
|
||||
|
||||
|
||||
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
|
||||
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
|
||||
|
Reference in New Issue
Block a user