Mirror of https://github.com/hpcaitech/ColossalAI.git
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* Fix FP16 optimizer and adapt torch amp to tensor parallel (#18)
* Fixed compatibility bugs between torch amp and tensor parallel, and applied minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* Improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
@@ -1,7 +1,7 @@
 from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
 from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
 from torch.optim.lr_scheduler import StepLR as _StepLR
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
 
 from colossalai.registry import LR_SCHEDULERS
 
@@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        def func(step): return lr_lambda(step // num_steps_per_epoch)
-
-        super().__init__(optimizer, func, last_epoch=last_epoch)
+    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
 @LR_SCHEDULERS.register_module
@@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        def func(step): return lr_lambda(step // num_steps_per_epoch)
-
-        super().__init__(optimizer, func, last_epoch=last_epoch)
+    def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
 @LR_SCHEDULERS.register_module
@@ -79,14 +73,13 @@ class StepLR(_StepLR):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        super().__init__(optimizer, step_size * num_steps_per_epoch,
+    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
+        super().__init__(optimizer, step_size,
                          gamma=gamma, last_epoch=last_epoch)
 
 
 @LR_SCHEDULERS.register_module
-class ExponentialLR(_LRScheduler):
+class ExponentialLR(_ExponentialLR):
     """Decays the learning rate of each parameter group by gamma every epoch.
     When last_epoch=-1, sets initial lr as lr
 
@@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler):
     :type last_epoch: int, optional
     """
 
-    def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1,
-                 last_epoch: int = -1) -> None:
-        self.gamma = gamma
-        self.num_steps_per_epoch = num_steps_per_epoch
-        super().__init__(optimizer, last_epoch=last_epoch)
-
-    def get_lr(self):
-        if self.last_epoch == 0:
-            return self.base_lrs
-        elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
-            return [group['lr'] * self.gamma
-                    for group in self.optimizer.param_groups]
-        return [group['lr']
-                for group in self.optimizer.param_groups]
-
-    def _get_closed_form_lr(self):
-        return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
-                for base_lr in self.base_lrs]
+    def __init__(self, optimizer, total_steps, gamma: float = 1.0,
+                 last_epoch: int = -1) -> None:
+        super().__init__(optimizer, gamma, last_epoch=last_epoch)
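The common thread of this diff: each wrapper drops its num_steps_per_epoch bookkeeping and forwards directly to the corresponding torch scheduler, so decay is applied per scheduler.step() call (i.e. per training iteration) rather than being rescaled to epoch boundaries; ExponentialLR likewise stops reimplementing get_lr() and simply subclasses torch's ExponentialLR. A minimal usage sketch (not part of the commit, shown with torch's own LambdaLR, which the simplified wrapper now delegates to):

# Assumes only stock PyTorch; the colossalai wrappers add a total_steps
# argument on top of this but otherwise behave the same after this change.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

# The lambda now receives the step count directly, with no
# step-to-epoch conversion in between.
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.95 ** step)

for step in range(10):
    optimizer.step()
    scheduler.step()  # lr becomes 0.1 * 0.95 ** (step + 1)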