Mirror of https://github.com/hpcaitech/ColossalAI.git
Commit: Refactored docstring to google style
@@ -21,8 +21,8 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
     def backward(self, loss: Tensor):
         """Backward pass to get all gradients

-        :param loss: Loss computed by a loss function
-        :type loss: torch.Tensor
+        Args:
+            loss (torch.Tensor): Loss computed by a loss function
         """
         with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
             scaled_loss.backward()
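For context (not part of the commit), a minimal usage sketch of the backward() shown above. The model, optimizer, and the import path for ApexAMPOptimizer are assumptions for illustration; apex must initialize the model and optimizer first so that apex_amp.scale_loss can track the optimizer's loss scaler.

    # Hedged usage sketch (not from the commit); import paths and names are assumptions.
    import torch
    import torch.nn as nn
    from apex import amp as apex_amp
    from colossalai.amp.apex_amp.apex_amp import ApexAMPOptimizer  # assumed module path

    model = nn.Linear(16, 4).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=1e-2)
    # apex must wrap the model/optimizer before scale_loss can be used
    model, optim = apex_amp.initialize(model, optim, opt_level='O1')
    optim = ApexAMPOptimizer(optim)

    data = torch.randn(8, 16).cuda()
    loss = model(data).sum()
    optim.backward(loss)   # internally: apex_amp.scale_loss(loss, optim), then backward()
    optim.step()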
@@ -30,10 +30,9 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         """Clip gradients' norm

-        :param model: Your model object
-        :type model: torch.nn.Module
-        :param max_norm: The max norm value for gradient clipping
-        :type max_norm: float
+        Args:
+            model (torch.nn.Module): Your model object
+            max_norm (float): The max norm value for gradient clipping
         """
         if max_norm > 0:
             clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
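Similarly, a hypothetical training step showing where clip_grad_norm() fits relative to backward() and step(); the max_norm value and variable names are illustrative only, and clipping only happens when max_norm > 0, as in the hunk above.

    # Hypothetical step (names and max_norm value are illustrative):
    optim.zero_grad()
    loss = model(data).sum()
    optim.backward(loss)                        # scaled backward, as in the first hunk
    optim.clip_grad_norm(model, max_norm=1.0)   # clips apex fp32 master-param gradients
    optim.step()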