Refactored docstring to google style

This commit is contained in:
Liang Bowen
2022-03-25 13:02:39 +08:00
committed by アマデウス
parent 53b1b6e340
commit ec5086c49c
94 changed files with 3389 additions and 2982 deletions

View File

@@ -21,8 +21,8 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
def backward(self, loss: Tensor):
"""Backward pass to get all gradients
:param loss: Loss computed by a loss function
:type loss: torch.Tensor
Args:
loss (torch.Tensor): Loss computed by a loss function
"""
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
scaled_loss.backward()
@@ -30,10 +30,9 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
def clip_grad_norm(self, model: nn.Module, max_norm: float):
"""Clip gradients' norm
:param model: Your model object
:type model: torch.nn.Module
:param max_norm: The max norm value for gradient clipping
:type max_norm: float
Args:
model (torch.nn.Module): Your model object
max_norm (float): The max norm value for gradient clipping
"""
if max_norm > 0:
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)