Refactored docstring to google style

This commit is contained in:
Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions

View File

@@ -4,17 +4,33 @@ from torch.optim import Optimizer
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
"""A helper function to wrap training components with Apex AMP modules
r"""A helper function to wrap training components with Apex AMP modules
:param model: your model object
:type model: :class:`torch.nn.Module`
:param optimizer: your optimizer object
:type optimizer: :class:`torch.optim.Optimizer`
:param amp_config: configuration for nvidia apex
:type amp_config: :class:`colossalai.context.Config` or dict
Args:
model (:class:`torch.nn.Module`): your model object.
optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
amp_config (:class: colossalai.context.Config or dict): configuration for initializing apex_amp.
:return: (model, optimizer)
:rtype: Tuple
The ``amp_config`` should include parameters below:
::
enabled (bool, optional, default=True)
opt_level (str, optional, default="O1")
cast_model_type (``torch.dtype``, optional, default=None)
patch_torch_functions (bool, optional, default=None)
keep_batchnorm_fp32 (bool or str, optional, default=None
master_weights (bool, optional, default=None)
loss_scale (float or str, optional, default=None)
cast_model_outputs (torch.dtype, optional, default=None)
num_losses (int, optional, default=1)
verbosity (int, default=1)
min_loss_scale (float, default=None)
max_loss_scale (float, default=2.**24)
Returns:
Tuples: A tuple (model, optimizer).
More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
"""
import apex.amp as apex_amp
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)

View File

@@ -21,8 +21,8 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
def backward(self, loss: Tensor):
"""Backward pass to get all gradients
:param loss: Loss computed by a loss function
:type loss: torch.Tensor
Args:
loss (torch.Tensor): Loss computed by a loss function
"""
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
scaled_loss.backward()
@@ -30,10 +30,9 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
def clip_grad_norm(self, model: nn.Module, max_norm: float):
"""Clip gradients' norm
:param model: Your model object
:type model: torch.nn.Module
:param max_norm: The max norm value for gradient clipping
:type max_norm: float
Args:
model (torch.nn.Module): Your model object
max_norm (float): The max norm value for gradient clipping
"""
if max_norm > 0:
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)