mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-05 01:59:57 +00:00)

[doc] improved docstring in the amp module (#857)

commit 9fdebadd69
parent b862d89d00
@@ -11,6 +11,9 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
         optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
         amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp.
 
+    Returns:
+        Tuple: A tuple (model, optimizer).
+
     The ``amp_config`` should include parameters below:
     ::
 
@@ -27,9 +30,6 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
         min_loss_scale (float, default=None)
         max_loss_scale (float, default=2.**24)
 
-    Returns:
-        Tuples: A tuple (model, optimizer).
-
     More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
     """
     import apex.amp as apex_amp
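Not part of the commit: a minimal usage sketch of the function documented above, assuming apex is installed and that ``convert_to_apex_amp`` is importable from ``colossalai.amp.apex_amp``; the model, optimizer and config values are placeholders::

    import torch
    import torch.nn as nn
    from colossalai.amp.apex_amp import convert_to_apex_amp

    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # keys mirror apex.amp.initialize arguments; the values here are illustrative only
    amp_config = dict(opt_level='O1', min_loss_scale=None, max_loss_scale=2.**24)
    model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)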
@@ -28,7 +28,7 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
             scaled_loss.backward()
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        """Clip gradients' norm
+        """Clip gradients by norm
 
         Args:
             model (torch.nn.Module): Your model object
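Not part of the commit: a hedged sketch of one training step with the apex-backed optimizer wrapper whose ``clip_grad_norm`` docstring is touched above, continuing the sketch before this hunk and assuming the returned optimizer exposes these wrapper methods::

    # continuing the previous sketch; `model` and `optimizer` come from convert_to_apex_amp
    data = torch.randn(8, 16, device='cuda')
    optimizer.zero_grad()
    loss = model(data).sum()
    optimizer.backward(loss)                       # scales the loss before calling backward()
    optimizer.clip_grad_norm(model, max_norm=1.0)  # the method documented in this hunk
    optimizer.step()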
@@ -17,6 +17,8 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
         optimizer (:class:`torch.optim.Optimizer`): your optimizer object
         amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
 
+    Returns:
+        Tuple: A tuple (model, optimizer)
+
     The ``amp_config`` should contain parameters below::
 
@@ -24,9 +26,6 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
         clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                           Note that clipping is ignored if clip_grad == 0.
         dynamic_grad_scale (bool): whether to use dynamic grad scaler.
 
-    Returns:
-        Tuples: A tuple (model, optimizer)
-
     """
     if isinstance(model, nn.ModuleList):
         # interleaved pipeline
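Not part of the commit: the naive-mode counterpart, sketched with only the config keys visible in this hunk; the import path and values are assumptions::

    import torch
    import torch.nn as nn
    from colossalai.amp.naive_amp import convert_to_naive_amp

    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # only the keys shown in the docstring above; clip_grad_norm=0 would disable clipping
    amp_config = dict(clip_grad_norm=1.0, dynamic_grad_scale=True)
    model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)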
@@ -152,18 +152,39 @@ class FP16Optimizer(Optimizer):
 
     @property
     def grad_scaler(self):
+        """Returns the gradient scaler.
+
+        Returns:
+            :class:`BaseGradScaler`: gradient scaler.
+        """
+
         return self._grad_scaler
 
     @property
     def loss_scale(self):
+        """Returns the loss scale.
+
+        Returns:
+            int: loss scale.
+        """
         return self._grad_scaler.scale
 
     @property
     def optimizer(self):
+        """Returns the optimizer.
+
+        Returns:
+            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
+        """
         return self._optimizer
 
     @property
     def defaults(self):
+        """Returns the default arguments of optimizer.
+
+        Returns:
+            dict: optimizer arguments saved in defaults of the optimizer wrapped.
+        """
         return self._defaults
 
     def _check_overflow(self):
@@ -188,6 +209,12 @@ class FP16Optimizer(Optimizer):
         return self._found_overflow.item() > 0
 
     def zero_grad(self, set_to_none=True):
+        """Set gradient to zero.
+
+        Args:
+            set_to_none (bool): Whether set the gradient to None.
+        """
+
         # set_to_none = True can save some memory space
         for param_group in self._optimizer.param_groups:
             zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
@@ -222,6 +249,9 @@ class FP16Optimizer(Optimizer):
                                              overflow_buf=self._dummy_overflow_buf)
 
     def step(self):
+        """Update the model parameters.
+        """
+
         # Copy gradients from model params to main params.
         self._assign_grad_to_fp32_master_param()
         self._unscale_grads()
@@ -248,10 +278,19 @@ class FP16Optimizer(Optimizer):
         return True, grad_norm
 
     def backward(self, loss):
+        """Execute backward pass.
+
+        Args:
+            loss (:class:`torch.Tensor`): the loss value.
+        """
+
         scaled_loss = loss * self.grad_scaler.scale
         scaled_loss.backward()
 
     def state_dict(self):
+        """Returns the states of the fp16 optimizer as a dict object.
+        """
+
         state_dict = {}
         state_dict['optimizer'] = self._optimizer.state_dict()
         if self.grad_scaler:
@@ -260,6 +299,12 @@ class FP16Optimizer(Optimizer):
         return state_dict
 
     def load_state_dict(self, state_dict):
+        """Load the states of the fp16 optimizer from a dict object.
+
+        Args:
+            state_dict (dict): the states of the fp16 optimizer
+        """
+
         # Optimizer.
         self._optimizer.load_state_dict(state_dict['optimizer'])
 
@@ -275,6 +320,11 @@ class FP16Optimizer(Optimizer):
             current_param.data.copy_(ckpt_param.data)
 
     def clip_grad_norm(self, clip_grad):
+        """Clip gradients by norm.
+
+        Args:
+            clip_grad (float): the max norm for clipping
+        """
         params = []
         for param_group in self._optimizer.param_groups:
             for param in param_group['params']:
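Not part of the commit: how the methods documented in the hunks above fit together in one iteration; ``fp16_optim`` stands for an already-constructed FP16Optimizer and ``loss`` for a scalar loss tensor, both of which are assumptions::

    fp16_optim.zero_grad(set_to_none=True)   # clear gradients, optionally releasing them
    fp16_optim.backward(loss)                # multiplies the loss by the current scale, then backward()
    fp16_optim.step()                        # copies grads to the fp32 master params and unscales them

    # checkpointing round-trips through plain dict objects
    ckpt = fp16_optim.state_dict()
    fp16_optim.load_state_dict(ckpt)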
@@ -3,6 +3,14 @@ from torch import Tensor
 
 
 def has_inf_or_nan(tensor):
+    """Check if tensor has inf or nan values.
+
+    Args:
+        tensor (:class:`torch.Tensor`): a torch tensor object
+
+    Returns:
+        bool: Whether the tensor has inf or nan. True for yes and False for no.
+    """
     try:
         # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
         # Pytorch's .sum() creates a one-element tensor of the same type as tensor
@@ -24,8 +32,8 @@ def has_inf_or_nan(tensor):
 
 
 def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
-    """
-    Clear the gradient of a list of tensors,
+    """Clear the gradient of a list of tensors,
+
     Note: copied from torch.optim.optimizer.
     """
     for param in tensor_list:
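Not part of the commit: a quick sketch of the two helpers documented above; the import path is an assumption::

    import torch
    from colossalai.amp.naive_amp._utils import has_inf_or_nan, zero_gard_by_list

    t = torch.tensor([1.0, float('inf')])
    print(has_inf_or_nan(t))          # True: the tensor contains an inf value

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.ones(4)
    zero_gard_by_list([p], set_to_none=True)
    print(p.grad)                     # None: the gradient was released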
@@ -11,6 +11,12 @@ __all__ = ['BaseGradScaler']
 
 
 class BaseGradScaler(ABC):
+    """A base class for the gradient scaler.
+
+    Args:
+        initial_scale (float): the initial loss scale
+        verbose (bool): whether to log messages
+    """
 
     def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
@@ -22,24 +28,53 @@ class BaseGradScaler(ABC):
 
     @property
     def scale(self) -> Tensor:
+        """Returns the loss scale.
+        """
+
         return self._scale
 
     @property
     def inv_scale(self) -> Tensor:
+        """Returns the inverse of the loss scale.
+        """
+
         return self._scale.double().reciprocal().float()
 
     def state_dict(self) -> Dict:
+        """Returns the states of the gradient scaler as a dict object.
+        """
+
         state_dict = dict()
         state_dict['scale'] = self.scale
         return state_dict
 
     def load_state_dict(self, state_dict: Dict) -> None:
+        """Load the states of the gradient scaler from a dict object.
+
+        Args:
+            state_dict (dict): the states of the gradient scaler
+        """
+
         self._scale = state_dict['scale']
 
     @abstractmethod
     def update(self, overflow: bool) -> None:
+        """Update the loss scale.
+
+        Args:
+            overflow (bool): whether overflow occurs
+        """
+
         pass
 
     def log(self, message, *args, **kwargs):
+        """Log messages.
+
+        Args:
+            message (str): the message to log
+            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
+            **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
+        """
+
         if self._verbose:
             self._logger.info(message, *args, **kwargs)
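Not part of the commit: a minimal subclass sketch showing what the abstract interface documented above requires; the class name is hypothetical and the import path is an assumption::

    from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler

    class NoOpGradScaler(BaseGradScaler):
        """A hypothetical scaler that never changes the scale it was given."""

        def update(self, overflow: bool) -> None:
            # the only abstract method: react to the overflow flag; here we ignore it
            pass

    # scale, inv_scale, state_dict() and load_state_dict() are inherited from BaseGradScaler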
@@ -6,11 +6,21 @@ __all__ = ['ConstantGradScaler']
 
 
 class ConstantGradScaler(BaseGradScaler):
+    """A gradient scaler which uses constant loss scale
+
+    Args:
+        initial_scale (float): the initial loss scale
+        verbose (bool): whether to log messages
+    """
 
     def __init__(self, initial_scale: int, verbose: bool):
         super().__init__(initial_scale, verbose)
         self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])
 
     def update(self, overflow: bool) -> None:
-        # do nothing to maintain the current scale value
+        """Do nothing to keep the loss scale constant.
+
+        Args:
+            overflow (bool): whether overflow occurs
+        """
         pass
@@ -9,6 +9,18 @@ __all__ = ['DynamicGradScaler']
 
 
 class DynamicGradScaler(BaseGradScaler):
+    """A gradient scaler which uses dynamic loss scale
+
+    Args:
+        initial_scale (float): the initial loss scale, defaults to 2**16
+        growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
+        backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
+        growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
+        min_scale (float): the minimum loss scale, defaults to None
+        max_scale (float): the maximum loss scale, defaults to None
+        hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
+        verbose (bool): whether to log messages, defaults to False
+    """
 
     def __init__(self,
                  initial_scale: float = 2**16,
@@ -39,6 +51,9 @@ class DynamicGradScaler(BaseGradScaler):
         self._sanity_checks()
 
     def _sanity_checks(self) -> None:
+        """Check if the arguments are correct.
+        """
+
         if self._min_scale:
             assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
         if self._max_scale:
|
|||||||
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
|
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
|
||||||
|
|
||||||
def update(self, overflow: bool) -> None:
|
def update(self, overflow: bool) -> None:
|
||||||
|
"""Update the loss scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
overflow (bool): whether overflow occurs
|
||||||
|
"""
|
||||||
if overflow:
|
if overflow:
|
||||||
self._hysteresis_step += 1
|
self._hysteresis_step += 1
|
||||||
self._growth_step = 0
|
self._growth_step = 0
|
||||||
@@ -67,11 +87,17 @@ class DynamicGradScaler(BaseGradScaler):
                      ranks=[0])
 
     def _backoff_scale(self) -> None:
+        """Decrease the loss scale
+        """
+
         self._scale = self._scale * self._backoff_factor
         if self._min_scale:
             self._scale = torch.max(self._scale, self._min_scale)
 
     def _grow_scale(self) -> None:
+        """Increase the loss scale
+        """
+
         self._scale = self._scale * self._growth_factor
         if self._max_scale:
             self._scale = torch.min(self._scale, self._max_scale)
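Not part of the commit: a sketch of how the two concrete scalers documented above react to ``update``; the constructor arguments follow the defaults listed in the new docstrings, the import path is an assumption, and a CUDA device plus an initialized logger may be required in practice::

    from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler

    dynamic = DynamicGradScaler(initial_scale=2**16, growth_factor=2, backoff_factor=0.5,
                                growth_interval=1000, hysteresis=2, verbose=False)
    dynamic.update(overflow=True)    # overflows accumulate; after `hysteresis` of them the scale backs off
    dynamic.update(overflow=False)   # clean steps count towards `growth_interval` before the scale grows

    constant = ConstantGradScaler(initial_scale=2**16, verbose=False)
    constant.update(overflow=True)   # documented above as a no-op: the scale never changes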
@@ -62,6 +62,9 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
 class TorchAMPModel(nn.Module):
     """A wrapper class for a model object which executes forward with values automatically
     cast to fp16
+
+    Args:
+        model (:class:`torch.nn.Module`): a torch model instance
     """
 
     def __init__(self, model: nn.Module) -> None:
@@ -70,6 +73,9 @@ class TorchAMPModel(nn.Module):
 
     @torch_amp.autocast()
     def forward(self, *args, **kwargs):
+        """
+        Execute forward under the torch amp context
+        """
         return self.model(*args, **kwargs)
 
 
@@ -86,4 +92,7 @@ class TorchAMPLoss(nn.Module):
 
     @torch_amp.autocast()
     def forward(self, *args, **kwargs):
+        """
+        Execute forward under the torch amp context
+        """
         return self.loss(*args, **kwargs)
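Not part of the commit: a hedged sketch of the torch amp wrappers documented above; the import path is an assumption, and in practice these wrappers are normally created for you when the torch AMP mode is selected during ``colossalai.initialize``::

    import torch
    import torch.nn as nn
    from colossalai.amp.torch_amp.torch_amp import TorchAMPModel, TorchAMPLoss

    model = TorchAMPModel(nn.Linear(16, 4).cuda())
    criterion = TorchAMPLoss(nn.CrossEntropyLoss())

    data = torch.randn(8, 16, device='cuda')
    target = torch.randint(0, 4, (8,), device='cuda')

    out = model(data)              # forward runs under the torch amp autocast context
    loss = criterion(out, target)  # the loss is computed under autocast as well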