Support TP-compatible Torch AMP and update trainer API (#27)

* Add gradient accumulation and fix the LR schedulers (a generic AMP + gradient-accumulation sketch follows the commit summary below)

* Fixed the FP16 optimizer and adapted torch AMP to tensor parallelism (#18)

* Fixed compatibility bugs between torch AMP and tensor parallelism and made some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* Improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Author: Frank Lee
Date: 2021-11-18 19:45:06 +08:00
Committed by: GitHub
Parent: 2b05de4c64
Commit: 3defa32aee
80 changed files with 2194 additions and 1584 deletions
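
The bullets above mention gradient accumulation and Torch AMP with tensor parallelism. The snippet below is a minimal, generic sketch of how AMP (autocast plus GradScaler) is typically combined with gradient accumulation in plain PyTorch; it is not the Colossal-AI engine code from this commit, and the names model, optimizer, criterion, dataloader and accum_steps are placeholders.

# Minimal sketch: Torch AMP (autocast + GradScaler) with gradient accumulation.
# Generic PyTorch pattern, not the Colossal-AI engine implementation;
# `model`, `optimizer`, `criterion`, `dataloader` and `accum_steps` are placeholders.
import torch

def train_epoch(model, optimizer, criterion, dataloader, accum_steps=4, device="cuda"):
    scaler = torch.cuda.amp.GradScaler()
    optimizer.zero_grad()
    for step, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        with torch.cuda.amp.autocast():
            output = model(data)
            # Divide so the accumulated gradient averages over accum_steps micro-batches.
            loss = criterion(output, target) / accum_steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)   # unscales gradients; skips the step on inf/nan
            scaler.update()
            optimizer.zero_grad()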


@@ -1,7 +1,7 @@
 from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
 from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
 from torch.optim.lr_scheduler import StepLR as _StepLR
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
 from colossalai.registry import LR_SCHEDULERS
@@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR):
     :type last_epoch: int, optional
     """
-    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        def func(step): return lr_lambda(step // num_steps_per_epoch)
-        super().__init__(optimizer, func, last_epoch=last_epoch)
+    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
 @LR_SCHEDULERS.register_module
@@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR):
     :type last_epoch: int, optional
     """
-    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        def func(step): return lr_lambda(step // num_steps_per_epoch)
-        super().__init__(optimizer, func, last_epoch=last_epoch)
+    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
 @LR_SCHEDULERS.register_module
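
With the change above, LambdaLR and MultiplicativeLR no longer divide the step counter by num_steps_per_epoch; the supplied callable now receives the raw scheduler step count, exactly as in the underlying torch.optim.lr_scheduler classes. The following sketch shows how a caller would express a step-based schedule under the new signature; the import path colossalai.nn.lr_scheduler and the warmup_steps value are assumptions for illustration.

# Sketch of using the refactored wrapper: the lambda now receives the raw
# scheduler step count, so step-based schedules are written directly.
# The import path and `warmup_steps` are assumptions for illustration;
# the constructor signature matches the diff above.
import torch
from colossalai.nn.lr_scheduler import LambdaLR  # path assumed

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_steps = 1000
warmup_steps = 100

# Linear warmup over the first `warmup_steps` optimizer steps, constant afterwards.
scheduler = LambdaLR(
    optimizer,
    total_steps=total_steps,
    lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
)

for step in range(total_steps):
    optimizer.step()
    scheduler.step()  # called once per training step, not once per epoch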
@@ -79,14 +73,13 @@ class StepLR(_StepLR):
     :type last_epoch: int, optional
     """
-    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        super().__init__(optimizer, step_size * num_steps_per_epoch,
+    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, step_size,
                          gamma=gamma, last_epoch=last_epoch)
 @LR_SCHEDULERS.register_module
-class ExponentialLR(_LRScheduler):
+class ExponentialLR(_ExponentialLR):
     """Decays the learning rate of each parameter group by gamma every epoch.
     When last_epoch=-1, sets initial lr as lr
@@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler):
     :type last_epoch: int, optional
     """
-    def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1,
+    def __init__(self, optimizer, total_steps, gamma: float = 1.0,
                  last_epoch: int = -1) -> None:
-        self.gamma = gamma
-        self.num_steps_per_epoch = num_steps_per_epoch
-        super().__init__(optimizer, last_epoch=last_epoch)
-    def get_lr(self):
-        if self.last_epoch == 0:
-            return self.base_lrs
-        elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
-            return [group['lr'] * self.gamma
-                    for group in self.optimizer.param_groups]
-        return [group['lr']
-                for group in self.optimizer.param_groups]
-    def _get_closed_form_lr(self):
-        return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
-                for base_lr in self.base_lrs]
+        super().__init__(optimizer, gamma, last_epoch=last_epoch)
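
Because StepLR and ExponentialLR now subclass the stock torch schedulers directly, decay happens on every scheduler.step() call instead of once per num_steps_per_epoch steps (the total_steps argument is still accepted but not forwarded to the parent class). The short demo below uses the stock torch.optim.lr_scheduler.ExponentialLR to illustrate the per-step behaviour the wrapper now inherits.

# Demonstrates the per-step behaviour the refactored ExponentialLR wrapper inherits.
# Shown with the stock torch class; the wrapper in this diff simply delegates to it.
import torch
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]["lr"])
# Prints approximately: 0 0.9 / 1 0.81 / 2 0.729 -- the lr is multiplied by gamma on
# every scheduler.step() call, so per-epoch decay now requires calling step() per epoch.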