update examples and sphnix docs for the new api (#63)

This commit is contained in:
Frank Lee
2021-12-13 22:07:01 +08:00
committed by GitHub
parent 7d3711058f
commit 35813ed3c4
124 changed files with 1251 additions and 1462 deletions

View File

@@ -7,6 +7,18 @@ import apex.amp as apex_amp
def convert_to_apex_amp(model: nn.Module,
optimizer: Optimizer,
amp_config):
"""A helper function to wrap training components with Torch AMP modules
:param model: your model object
:type model: :class:`torch.nn.Module`
:param optimizer: your optimizer object
:type optimizer: :class:`torch.optim.Optimzer`
:param amp_config: configuration for nvidia apex
:type amp_config: :class:`colossalai.context.Config` or dict
:return: (model, optimizer)
:rtype: Tuple
"""
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
optimizer = ApexAMPOptimizer(optimizer)
return model, optimizer

View File

@@ -13,11 +13,24 @@ from colossalai.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer):
''' A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
methods
'''
def backward(self, loss: Tensor):
"""
:param loss: loss computed by a loss function
:type loss: torch.Tensor
"""
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
scaled_loss.backward()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
"""
:param model: your model object
:type model: torch.nn.Module
:param max_norm: the max norm value for gradient clipping
:type max_norm: float
"""
if max_norm > 0:
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)