mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
update examples and sphinx docs for the new api (#63)
@@ -13,11 +13,24 @@ from colossalai.utils import clip_grad_norm_fp32


class ApexAMPOptimizer(ColossalaiOptimizer):
    '''A wrapper class for the APEX optimizer. It implements the apex-specific
    backward and clip_grad_norm methods.
    '''

    def backward(self, loss: Tensor):
        """
        :param loss: loss computed by a loss function
        :type loss: torch.Tensor
        """
        # apex scales the loss before backward so fp16 gradients do not underflow
        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """
        :param model: your model object
        :type model: torch.nn.Module
        :param max_norm: the max norm value for gradient clipping
        :type max_norm: float
        """
        if max_norm > 0:
            # clip the fp32 master copies of the parameters kept by apex
            clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
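For context, here is a minimal sketch of how such a wrapper is typically driven in a training loop. The import path for ApexAMPOptimizer, the opt_level, the toy model/data, and the assumption that the wrapper is constructed from the apex-initialized optimizer (and that the base ColossalaiOptimizer forwards step()/zero_grad() to it) are illustrative assumptions, not taken from this commit; NVIDIA apex must be installed separately.

    import torch
    import torch.nn as nn
    from apex import amp as apex_amp

    # hypothetical import path; the actual module location depends on the repo version
    from colossalai.amp.apex_amp import ApexAMPOptimizer

    model = nn.Linear(32, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # let apex patch the model/optimizer for mixed precision first ...
    model, optimizer = apex_amp.initialize(model, optimizer, opt_level='O1')

    # ... then wrap the patched optimizer so training code can use the unified
    # backward / clip_grad_norm interface shown in the diff above
    optimizer = ApexAMPOptimizer(optimizer)

    criterion = nn.CrossEntropyLoss()
    inputs = torch.randn(8, 32).cuda()
    labels = torch.randint(0, 10, (8,)).cuda()

    loss = criterion(model(inputs), labels)
    optimizer.backward(loss)                       # scaled-loss backward
    optimizer.clip_grad_norm(model, max_norm=1.0)  # clip fp32 master grads
    optimizer.step()
    optimizer.zero_grad()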