Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
Frank Lee
2021-11-18 19:45:06 +08:00
committed by GitHub
parent 2b05de4c64
commit 3defa32aee
80 changed files with 2194 additions and 1584 deletions

View File

@@ -0,0 +1,58 @@
from torch import Tensor
from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
from .._trainer import Trainer
from ..metric import LearningRate
@HOOKS.register_module
class LRSchedulerHook(MetricHook):
"""Build LR scheduler
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param lr_scheduler_cfg: The config of LR scheduler
:type lr_scheduler_cfg: dict
:param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`.
:type by_epoch: bool
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
"""
def __init__(self,
trainer: Trainer,
lr_scheduler_cfg: dict,
by_epoch: bool = True,
store_lr_in_state: bool = True,
priority: int = 1,
):
super().__init__(trainer=trainer, priority=priority)
self.by_epoch = by_epoch
if by_epoch:
total_steps = trainer.max_epochs
else:
total_steps = trainer.max_epochs * trainer.steps_per_epoch
if trainer.max_steps is not None:
total_steps = min(total_steps, trainer.max_steps)
lr_scheduler_cfg['total_steps'] = total_steps
self.lr_scheduler = build_lr_scheduler(
lr_scheduler_cfg, trainer.engine.optimizer)
if store_lr_in_state:
self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
initial_lr=self.lr_scheduler.get_lr()[0])
def after_train_epoch(self):
if self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])