[doc] improved docstring and assertion messages for the engine module (#871)

Frank Lee
2022-04-26 10:00:18 +08:00
committed by GitHub
parent 1c34382678
commit 11f54c7b6b
9 changed files with 180 additions and 60 deletions


@@ -1,9 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union
import torch.nn as nn
from torch import Tensor
from typing import Iterable, Any
from typing import Iterable, Any, Tuple, List
from colossalai.nn.optimizer import ColossalaiOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer
@@ -33,24 +34,54 @@ class GradAccumOptimizer(ColossalaiOptimizer):
self.model = model
self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
def zero_grad(self, *args, **kwargs):
def zero_grad(self, *args, **kwargs) -> None:
"""
Set all gradients to zero.
Args:
*args: positional arguments for the wrapped optimizer
**kwargs: keyword arguments for the wrapped optimizer
"""
if self.accumulate_step == 0:
self.optim.zero_grad(*args, **kwargs)
def step(self, *args, **kwargs):
def step(self, *args, **kwargs) -> None:
"""
Update the model parameters.
Args:
*args: positional arguments for the wrapped optimizer
**kwargs: keyword arguments for the wrapped optimizer
"""
if self.accumulate_step < self.accumulate_size:
return None
else:
self.accumulate_step = 0
return self.optim.step(*args, **kwargs)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None:
"""
Clip gradients by norm.
Args:
model (:class:`torch.nn.Module`): a torch module instance
max_norm (float): the max norm for gradient clipping
"""
if self.accumulate_step < self.accumulate_size:
pass
else:
self.optim.clip_grad_norm(model, max_norm)
def backward(self, loss: Tensor):
def backward(self, loss: Tensor) -> None:
"""Execute backward pass.
Args:
loss (:class:`torch.Tensor`): the loss value.
"""
self.accumulate_step += 1
if self.is_torch_ddp:
@@ -62,7 +93,14 @@ class GradAccumOptimizer(ColossalaiOptimizer):
scaled_loss = loss / self.accumulate_size
self.optim.backward(scaled_loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
"""Execute backward pass given the gradients of the output.
Args:
tensor (:class:`torch.Tensor`): the output tensor.
grad (:class:`torch.Tensor`): the output gradient.
"""
self.accumulate_step += 1
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
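Taken together, these methods defer the real optimizer update until accumulate_size backward passes have accumulated. A minimal usage sketch, not part of this commit; the constructor argument names and the model, criterion and train_dataloader objects are assumed for illustration:

inner_optim = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = GradAccumOptimizer(inner_optim, accumulate_size=4, model=model)

for data, label in train_dataloader:
    loss = criterion(model(data), label)
    optimizer.backward(loss)   # increments the counter and scales the loss by 1/accumulate_size
    optimizer.step()           # no-op until the 4th micro-batch, then steps and resets the counter
    optimizer.zero_grad()      # only clears gradients right after a real step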
@@ -84,7 +122,7 @@ class GradAccumDataloader:
(e.g. Dali dataloader), this class will automatically consume (load and discard) the remaining 2 batches.
Args:
optim (``Iterable``): Your dataloader object for gradient accumulation.
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
@@ -96,15 +134,15 @@ class GradAccumDataloader:
def __getattr__(self, __name: str) -> Any:
return getattr(self.dataloader, __name)
def __len__(self):
def __len__(self) -> int:
return self.steps_per_epoch
def __iter__(self):
def __iter__(self) -> Iterable:
self._cur_step = 0
self._dataiter = iter(self.dataloader)
return self
def __next__(self) -> Any:
def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
if self._cur_step < self.steps_per_epoch:
self._cur_step += 1
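The steps_per_epoch bound above is presumably the largest multiple of accumulate_size that fits in len(dataloader), so each epoch yields only whole accumulation cycles. A rough sketch of that arithmetic with illustrative numbers, not taken from the source:

batches = 22                     # hypothetical dataloader length
accumulate_size = 4
steps_per_epoch = batches // accumulate_size * accumulate_size   # 20 batches are yielded
leftover = batches - steps_per_epoch                             # 2 trailing batches are skipped
                                                                 # (or consumed unused, for stateful loaders)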
@@ -137,13 +175,30 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
self.accumulate_step = 0
@staticmethod
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int):
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int:
"""
Computes the number of effective training iterations. An effective iteration is defined
as the aggregation of <accumulate_size> iterations. For example, if accumulate_size = 4,
then 4 iterations are considered one effective iteration.
Args:
dataloader (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
return len(dataloader) // accumulate_size
def __getattr__(self, __name: str) -> Any:
return getattr(self.lr_scheduler, __name)
def step(self, *args, **kwargs):
def step(self, *args, **kwargs) -> None:
"""
Update the learning rate.
Args:
*args: positional arguments for the wrapped lr scheduler.
**kwargs: keyword arguments for the wrapped lr scheduler.
"""
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
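A small worked example of compute_effective_steps_per_epoch and of the counting in step() above; train_dataloader is a placeholder:

# 100 batches with accumulate_size = 4 -> 25 effective (optimizer) steps per epoch.
effective_steps = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(
    train_dataloader, accumulate_size=4)    # len(train_dataloader) // 4
# step() forwards to the wrapped scheduler only once per accumulation cycle, so a
# per-iteration scheduler should be sized with effective_steps, not the raw length.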
@@ -151,19 +206,52 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
self.accumulate_step = 0
self.lr_scheduler.step(*args, **kwargs)
def get_lr(self):
def get_lr(self) -> List[float]:
"""
Compute the next learning rate.
Returns:
List[float]: the upcoming learning rates.
"""
return self.lr_scheduler.get_lr()
def get_last_lr(self):
def get_last_lr(self) -> List[float]:
"""
Returns the current learning rate.
Returns:
List[float]: the current learning rates.
"""
return self.lr_scheduler.get_last_lr()
def print_lr(self, *args, **kwargs):
def print_lr(self, *args, **kwargs) -> None:
"""
Print the learning rate.
Args:
*args: positional arguments for the wrapped lr scheduler.
**kwargs: keyword arguments for the wrapped lr scheduler.
"""
self.lr_scheduler.print_lr(*args, **kwargs)
def state_dict(self) -> dict:
"""
Returns the states of the lr scheduler as a dictionary.
Returns:
dict: the states of the lr scheduler.
"""
return self.lr_scheduler.state_dict()
def load_state_dict(self, state_dict: dict) -> None:
"""
Load the states of the lr scheduler from a dictionary object.
Args:
state_dict (dict): the states of the lr scheduler.
"""
self.lr_scheduler.load_state_dict(state_dict)
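Since state_dict() and load_state_dict() simply delegate to the wrapped scheduler, checkpointing works as with a plain torch scheduler. A short sketch with a hypothetical wrapper instance named scheduler:

checkpoint = {"lr_scheduler": scheduler.state_dict()}
# ... save and later reload the checkpoint ...
scheduler.load_state_dict(checkpoint["lr_scheduler"])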
@@ -188,7 +276,11 @@ class GradAccumGradientHandler:
self.accumulate_size = accumulate_size
self.accumulate_step = 0
def handle_gradient(self):
def handle_gradient(self) -> None:
"""
Handle gradients reduction only in the last gradient accumulation step.
"""
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
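For reference, the counter in handle_gradient() follows the same pattern as the other wrappers: the wrapped gradient reduction only fires on the last micro-batch of each accumulation cycle. A standalone sketch of the pattern, not the actual handler code:

accumulate_size, accumulate_step = 4, 0
for micro_batch in range(8):
    accumulate_step += 1
    if accumulate_step < accumulate_size:
        pass                    # intermediate micro-batch: skip the all-reduce
    else:
        accumulate_step = 0     # last micro-batch of the cycle: reduce gradients here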