mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
Refactored docstring to google style
This commit is contained in:
@@ -4,17 +4,33 @@ from torch.optim import Optimizer
|
||||
|
||||
|
||||
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
|
||||
"""A helper function to wrap training components with Apex AMP modules
|
||||
r"""A helper function to wrap training components with Apex AMP modules
|
||||
|
||||
:param model: your model object
|
||||
:type model: :class:`torch.nn.Module`
|
||||
:param optimizer: your optimizer object
|
||||
:type optimizer: :class:`torch.optim.Optimizer`
|
||||
:param amp_config: configuration for nvidia apex
|
||||
:type amp_config: :class:`colossalai.context.Config` or dict
|
||||
Args:
|
||||
model (:class:`torch.nn.Module`): your model object.
|
||||
optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
|
||||
amp_config (:class: colossalai.context.Config or dict): configuration for initializing apex_amp.
|
||||
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
The ``amp_config`` should include parameters below:
|
||||
::
|
||||
|
||||
enabled (bool, optional, default=True)
|
||||
opt_level (str, optional, default="O1")
|
||||
cast_model_type (``torch.dtype``, optional, default=None)
|
||||
patch_torch_functions (bool, optional, default=None)
|
||||
keep_batchnorm_fp32 (bool or str, optional, default=None
|
||||
master_weights (bool, optional, default=None)
|
||||
loss_scale (float or str, optional, default=None)
|
||||
cast_model_outputs (torch.dtype, optional, default=None)
|
||||
num_losses (int, optional, default=1)
|
||||
verbosity (int, default=1)
|
||||
min_loss_scale (float, default=None)
|
||||
max_loss_scale (float, default=2.**24)
|
||||
|
||||
Returns:
|
||||
Tuples: A tuple (model, optimizer).
|
||||
|
||||
More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
|
||||
"""
|
||||
import apex.amp as apex_amp
|
||||
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
|
||||
|
@@ -21,8 +21,8 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
|
||||
def backward(self, loss: Tensor):
|
||||
"""Backward pass to get all gradients
|
||||
|
||||
:param loss: Loss computed by a loss function
|
||||
:type loss: torch.Tensor
|
||||
Args:
|
||||
loss (torch.Tensor): Loss computed by a loss function
|
||||
"""
|
||||
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
@@ -30,10 +30,9 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
|
||||
def clip_grad_norm(self, model: nn.Module, max_norm: float):
|
||||
"""Clip gradients' norm
|
||||
|
||||
:param model: Your model object
|
||||
:type model: torch.nn.Module
|
||||
:param max_norm: The max norm value for gradient clipping
|
||||
:type max_norm: float
|
||||
Args:
|
||||
model (torch.nn.Module): Your model object
|
||||
max_norm (float): The max norm value for gradient clipping
|
||||
"""
|
||||
if max_norm > 0:
|
||||
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
|
||||
|
Reference in New Issue
Block a user