Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 05:01:44 +00:00)

Commit: Migrated project
colossalai/nn/lr_scheduler/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
from .linear import LinearWarmupLR, LinearWarmupDecay
from .multistep import MultiStepLR, MultiStepWarmupLR
from .onecycle import OneCycleLR
from .poly import PolynomialLR, PolynomialWarmupLR
from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR

__all__ = [
    'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR',
    'MultiStepLR', 'MultiStepWarmupLR', 'OneCycleLR', 'PolynomialLR', 'PolynomialWarmupLR', 'LambdaLR',
    'MultiplicativeLR', 'StepLR', 'ExponentialLR'
]
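For orientation, a minimal usage sketch (illustrative only, not part of this commit): the schedulers exported here wrap any torch.optim optimizer; the model and hyper-parameters below are placeholders.

# Illustrative sketch -- assumes this ColossalAI package is installed.
import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

model = torch.nn.Linear(16, 4)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1000 optimizer steps in total, the first 100 spent in linear warmup
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=1000, warmup_steps=100)

for step in range(1000):
    optimizer.step()      # actual forward/backward omitted
    scheduler.step()      # one scheduler step per optimizer step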
colossalai/nn/lr_scheduler/cosine.py (new file, 129 lines)
@@ -0,0 +1,129 @@
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR

from colossalai.registry import LR_SCHEDULERS
from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler


@LR_SCHEDULERS.register_module
class CosineAnnealingLR(_CosineAnnealingLR):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max}; \\
        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param eta_min: Minimum learning rate, defaults to 0
    :type eta_min: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, eta_min: int = 0, last_epoch: int = -1, **kwargs):
        super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class CosineAnnealingWarmupLR(WarmupScheduler):
    """Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param warmup_steps: number of warmup steps, defaults to 0
    :type warmup_steps: int, optional
    :param eta_min: Minimum learning rate, defaults to 0
    :type eta_min: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
                 **kwargs):
        base_scheduler = _CosineAnnealingLR(
            optimizer, total_steps - warmup_steps, eta_min=eta_min)
        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class FlatAnnealingLR(DelayerScheduler):
    """Flat and cosine annealing learning rate scheduler. The learning rate will be a fixed value before starting decay.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param pct_start: percent of steps before starting learning rate decay
    :type pct_start: float
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs):
        if not (0.0 <= pct_start <= 1.0):
            raise ValueError(
                f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
        flat_steps = int(total_steps * pct_start)
        anneal_steps = total_steps - flat_steps
        base_scheduler = _CosineAnnealingLR(
            optimizer, anneal_steps)
        super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class FlatAnnealingWarmupLR(WarmupDelayerScheduler):
    """Flat and cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be
    applied, and then the learning rate will be a fixed value before starting decay.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param warmup_steps: number of warmup steps, defaults to 0
    :type warmup_steps: int, optional
    :param pct_start: percent of steps before starting learning rate decay
    :type pct_start: float
    :param eta_min: Minimum learning rate, defaults to 0
    :type eta_min: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, pct_start: float = 0.72, eta_min: int = 0,
                 last_epoch: int = -1, **kwargs):
        if not (0.0 <= pct_start <= 1.0):
            raise ValueError(
                f'pct_start must >= 0.0 and <= 1.0, got {pct_start}')
        flat_steps = int((total_steps - warmup_steps) * pct_start)
        anneal_steps = total_steps - warmup_steps - flat_steps
        base_scheduler = _CosineAnnealingLR(
            optimizer, anneal_steps, eta_min=eta_min)
        super().__init__(optimizer, warmup_steps, flat_steps,
                         base_scheduler, last_epoch=last_epoch)
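To make the phase bookkeeping in FlatAnnealingWarmupLR concrete, here is a small sketch with assumed values (not part of the file) that reproduces the arithmetic in its __init__:

# Illustrative values only.
total_steps, warmup_steps, pct_start = 1000, 100, 0.72

flat_steps = int((total_steps - warmup_steps) * pct_start)   # 648 flat steps after warmup
anneal_steps = total_steps - warmup_steps - flat_steps       # 252 cosine-annealing steps

assert (warmup_steps, flat_steps, anneal_steps) == (100, 648, 252)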
colossalai/nn/lr_scheduler/delayed.py (new file, 149 lines)
@@ -0,0 +1,149 @@
from torch.optim.lr_scheduler import _LRScheduler


class _enable_get_lr_call:

    def __init__(self, o):
        self.o = o

    def __enter__(self):
        self.o._get_lr_called_within_step = True
        return self

    def __exit__(self, type, value, traceback):
        self.o._get_lr_called_within_step = False


class DelayerScheduler(_LRScheduler):
    """Starts with a flat lr schedule until it reaches N epochs, then applies the given scheduler.

    :param optimizer: Wrapped optimizer.
    :type optimizer: torch.optim.Optimizer
    :param delay_epochs: number of epochs to keep the initial lr before starting to apply the scheduler
    :type delay_epochs: int
    :param after_scheduler: scheduler to use after the delay (e.g. ReduceLROnPlateau)
    :type after_scheduler: torch.optim.lr_scheduler
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1):
        if delay_epochs < 0:
            raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}')
        self.delay_epochs = delay_epochs
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch >= self.delay_epochs:
            if not self.finished:
                self.after_scheduler.base_lrs = self.base_lrs
                self.finished = True
            with _enable_get_lr_call(self.after_scheduler):
                return self.after_scheduler.get_lr()

        return self.base_lrs

    def step(self, epoch=None):
        if self.finished:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.delay_epochs)
        else:
            return super(DelayerScheduler, self).step(epoch)


class WarmupScheduler(_LRScheduler):
    """Starts with a linear warmup lr schedule until it reaches N epochs, then applies the given scheduler.

    :param optimizer: Wrapped optimizer.
    :type optimizer: torch.optim.Optimizer
    :param warmup_epochs: number of epochs to linearly warm up the lr before starting to apply the scheduler
    :type warmup_epochs: int
    :param after_scheduler: scheduler to use after warmup (e.g. ReduceLROnPlateau)
    :type after_scheduler: torch.optim.lr_scheduler
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
        if warmup_epochs < 0:
            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
        self.warmup_epochs = warmup_epochs
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch >= self.warmup_epochs:
            if not self.finished:
                self.after_scheduler.base_lrs = self.base_lrs
                # reset lr to base_lr
                for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
                    group['lr'] = base_lr
                self.finished = True
            with _enable_get_lr_call(self.after_scheduler):
                return self.after_scheduler.get_lr()

        return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]

    def step(self, epoch=None):
        if self.finished:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.warmup_epochs)
        else:
            return super().step(epoch)


class WarmupDelayerScheduler(_LRScheduler):
    """Starts with a linear warmup lr schedule until it reaches N epochs, keeps a flat lr schedule until it reaches
    N + M epochs, then applies the given scheduler.

    :param optimizer: Wrapped optimizer.
    :type optimizer: torch.optim.Optimizer
    :param warmup_epochs: number of epochs to linearly warm up the lr before the flat phase
    :type warmup_epochs: int
    :param delay_epochs: number of epochs to keep the initial lr before starting to apply the scheduler
    :type delay_epochs: int
    :param after_scheduler: scheduler to use after the delay (e.g. ReduceLROnPlateau)
    :type after_scheduler: torch.optim.lr_scheduler
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1):
        if delay_epochs < 0:
            raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}')
        if warmup_epochs < 0:
            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
        self.warmup_epochs = warmup_epochs
        self.delay_epochs = delay_epochs
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
            if not self.finished:
                self.after_scheduler.base_lrs = self.base_lrs
                # reset lr to base_lr
                for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
                    group['lr'] = base_lr
                self.finished = True
            with _enable_get_lr_call(self.after_scheduler):
                return self.after_scheduler.get_lr()
        elif self.last_epoch >= self.warmup_epochs:
            return self.base_lrs

        return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]

    def step(self, epoch=None):
        if self.finished:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.warmup_epochs)
        else:
            return super().step(epoch)
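To make the warmup ramp concrete, a sketch with assumed numbers (not part of the file): WarmupScheduler scales each base lr by (last_epoch + 1) / (warmup_epochs + 1) until warmup ends, after which the wrapped after_scheduler takes over.

# Warmup factors produced by WarmupScheduler.get_lr() for warmup_epochs=4 and a base lr of 0.1.
warmup_epochs, base_lr = 4, 0.1

warmup_lrs = [(epoch + 1) / (warmup_epochs + 1) * base_lr for epoch in range(warmup_epochs)]
print([round(lr, 3) for lr in warmup_lrs])   # [0.02, 0.04, 0.06, 0.08], then the after_scheduler runs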
colossalai/nn/lr_scheduler/linear.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.registry import LR_SCHEDULERS


@LR_SCHEDULERS.register_module
class LinearWarmupLR(_LRScheduler):
    """Linearly warms up the learning rate and then linearly decays it.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param warmup_steps: number of warmup steps, defaults to 0
    :type warmup_steps: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs]
        else:
            return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in
                    self.base_lrs]


@LR_SCHEDULERS.register_module
class LinearWarmupDecay(_LRScheduler):

    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs):
        self.warmup_steps = int(warmup_steps)
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [(self.last_epoch + 1) / self.warmup_steps * lr for lr in self.base_lrs]
        else:
            return [(self.total_steps - self.last_epoch - 1) / (self.total_steps - self.warmup_steps) * lr for lr in
                    self.base_lrs]
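As a rough illustration with assumed values (not part of the file), LinearWarmupLR rises for warmup_steps steps and then decays linearly toward zero at total_steps, mirroring the two branches of get_lr():

# LR trajectory for total_steps=10, warmup_steps=3 and a base lr of 1.0.
total_steps, warmup_steps, base_lr = 10, 3, 1.0

lrs = []
for step in range(total_steps):
    if step < warmup_steps:
        lrs.append((step + 1) / (warmup_steps + 1) * base_lr)
    else:
        lrs.append((total_steps - step) / (total_steps - warmup_steps) * base_lr)
print([round(lr, 3) for lr in lrs])   # rises during warmup, peaks at 1.0, then decays toward 0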
colossalai/nn/lr_scheduler/multistep.py (new file, 70 lines)
@@ -0,0 +1,70 @@
from typing import List

from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR

from colossalai.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler


@LR_SCHEDULERS.register_module
class MultiStepLR(_MultiStepLR):
    """Decays the learning rate of each parameter group by gamma once the
    number of epochs reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param milestones: List of epoch indices. Must be increasing, defaults to None
    :type milestones: List[int], optional
    :param gamma: Multiplicative factor of learning rate decay, defaults to 0.1
    :type gamma: float, optional
    :param num_steps_per_epoch: number of steps per epoch, defaults to -1
    :type num_steps_per_epoch: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1,
                 num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
        if num_steps_per_epoch <= 0:
            raise ValueError(
                f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
        milestones = [v * num_steps_per_epoch for v in milestones]
        super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class MultiStepWarmupLR(WarmupScheduler):
    """Multi-step learning rate scheduler with warmup.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param warmup_steps: number of warmup steps, defaults to 0
    :type warmup_steps: int, optional
    :param milestones: List of epoch indices. Must be increasing, defaults to None
    :type milestones: List[int], optional
    :param gamma: Multiplicative factor of learning rate decay, defaults to 0.1
    :type gamma: float, optional
    :param num_steps_per_epoch: number of steps per epoch, defaults to -1
    :type num_steps_per_epoch: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
                 gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
        if len(milestones) == 0:
            raise ValueError('milestones cannot be empty')
        if num_steps_per_epoch <= 0:
            raise ValueError(
                f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
        milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v *
                      num_steps_per_epoch >= warmup_steps]
        base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
                                      gamma=gamma)
        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
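The milestone handling above converts epoch indices into optimizer-step indices counted from the end of warmup; a sketch with assumed numbers (not part of the file) shows the conversion done in MultiStepWarmupLR.__init__:

# Epoch milestones -> step milestones, with milestones before the end of warmup dropped.
milestones, num_steps_per_epoch, warmup_steps = [30, 60, 90], 100, 500

step_milestones = [v * num_steps_per_epoch - warmup_steps
                   for v in milestones if v * num_steps_per_epoch >= warmup_steps]
print(step_milestones)   # [2500, 5500, 8500]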
colossalai/nn/lr_scheduler/onecycle.py (new file, 97 lines)
@@ -0,0 +1,97 @@
from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR

from colossalai.registry import LR_SCHEDULERS


@LR_SCHEDULERS.register_module
class OneCycleLR(_OneCycleLR):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param pct_start: The percentage of the cycle (in number of steps) spent increasing the learning rate, defaults to 0.3
    :type pct_start: float, optional
    :param anneal_strategy: {'cos', 'linear'}
        Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
        linear annealing, defaults to 'cos'
    :type anneal_strategy: str, optional
    :param cycle_momentum: If ``True``, momentum is cycled inversely
        to learning rate between 'base_momentum' and 'max_momentum', defaults to True
    :type cycle_momentum: bool, optional
    :param base_momentum: Lower momentum boundaries in the cycle
        for each parameter group. Note that momentum is cycled inversely
        to learning rate; at the peak of a cycle, momentum is
        'base_momentum' and learning rate is 'max_lr', defaults to 0.85
    :type base_momentum: float, optional
    :param max_momentum: Upper momentum boundaries in the cycle
        for each parameter group. Functionally,
        it defines the cycle amplitude (max_momentum - base_momentum).
        Note that momentum is cycled inversely
        to learning rate; at the start of a cycle, momentum is 'max_momentum'
        and learning rate is 'base_lr', defaults to 0.95
    :type max_momentum: float, optional
    :param div_factor: Determines the initial learning rate via
        initial_lr = max_lr/div_factor, defaults to 25.0
    :type div_factor: float, optional
    :param final_div_factor: Determines the minimum learning rate via
        min_lr = initial_lr/final_div_factor, defaults to 10000.0
    :type final_div_factor: float, optional
    :param last_epoch: The index of the last batch. This parameter is used when
        resuming a training job. Since `step()` should be invoked after each
        batch instead of after each epoch, this number represents the total
        number of *batches* computed, not the total number of epochs computed.
        When last_epoch=-1, the schedule is started from the beginning, defaults to -1
    :type last_epoch: int, optional

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """

    def __init__(self, optimizer, total_steps: int,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.0,
                 final_div_factor=10000.0,
                 last_epoch=-1, **kwargs):
        max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups))
        super().__init__(optimizer, max_lrs, total_steps=total_steps,
                         pct_start=pct_start,
                         anneal_strategy=anneal_strategy,
                         cycle_momentum=cycle_momentum,
                         base_momentum=base_momentum,
                         max_momentum=max_momentum,
                         div_factor=div_factor,
                         final_div_factor=final_div_factor,
                         last_epoch=last_epoch)
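A hedged usage sketch (placeholder model and values, not part of this commit): the wrapper reads max_lr for each parameter group from the optimizer's current lr, so only total_steps has to be supplied explicitly.

# Illustrative sketch -- assumes this ColossalAI package is installed.
import torch
from colossalai.nn.lr_scheduler import OneCycleLR

model = torch.nn.Linear(8, 2)                              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)    # 0.1 becomes max_lr

scheduler = OneCycleLR(optimizer, total_steps=200, pct_start=0.3)
for _ in range(200):
    optimizer.step()      # actual training step omitted
    scheduler.step()      # called once per batch, per the docstring above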
colossalai/nn/lr_scheduler/poly.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler


@LR_SCHEDULERS.register_module
class PolynomialLR(_LRScheduler):
    """Polynomial learning rate scheduler.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param end_lr: Minimum learning rate, defaults to 0.0001
    :type end_lr: float, optional
    :param power: the power of polynomial, defaults to 1.0
    :type power: float, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, end_lr: float = 0.0001, power: float = 1.0, last_epoch: int = -1,
                 **kwargs):
        if end_lr < 0:
            raise ValueError(f'end_lr must >= 0, got {end_lr}')
        self.total_steps = total_steps
        self.end_lr = end_lr
        self.power = power
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        return [
            (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.total_steps) /
                                        self.total_steps) ** self.power) + self.end_lr
            for base_lr in self.base_lrs
        ]


@LR_SCHEDULERS.register_module
class PolynomialWarmupLR(WarmupScheduler):
    """Polynomial learning rate scheduler with warmup.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param warmup_steps: number of warmup steps, defaults to 0
    :type warmup_steps: int, optional
    :param end_lr: Minimum learning rate, defaults to 0.0001
    :type end_lr: float, optional
    :param power: the power of polynomial, defaults to 1.0
    :type power: float, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, end_lr: float = 0.0001, power: float = 1.0,
                 last_epoch: int = -1, **kwargs):
        base_scheduler = PolynomialLR(
            optimizer, total_steps - warmup_steps, end_lr=end_lr, power=power)
        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
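To illustrate the closed form above with assumed numbers (not part of the file): the decay factor (1 - t / total_steps) ** power is applied to the gap between the base lr and end_lr.

# Closed-form LR of PolynomialLR for base_lr=0.1, end_lr=0.0001, power=2.0, total_steps=4.
base_lr, end_lr, power, total_steps = 0.1, 0.0001, 2.0, 4

for step in range(total_steps + 1):
    lr = (base_lr - end_lr) * ((1 - min(step, total_steps) / total_steps) ** power) + end_lr
    print(step, round(lr, 5))   # decays from 0.1 at step 0 down to end_lr at step 4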
colossalai/nn/lr_scheduler/torch.py (new file, 122 lines)
@@ -0,0 +1,122 @@
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
from torch.optim.lr_scheduler import StepLR as _StepLR
from torch.optim.lr_scheduler import _LRScheduler

from colossalai.registry import LR_SCHEDULERS


@LR_SCHEDULERS.register_module
class LambdaLR(_LambdaLR):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param lr_lambda: A function which computes a multiplicative
        factor given an integer parameter epoch, or a list of such
        functions, one for each group in optimizer.param_groups, defaults to None
    :type lr_lambda: function or list, optional
    :param num_steps_per_epoch: number of steps per epoch, defaults to -1
    :type num_steps_per_epoch: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
                 last_epoch: int = -1) -> None:
        def func(step): return lr_lambda(step // num_steps_per_epoch)

        super().__init__(optimizer, func, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class MultiplicativeLR(_MultiplicativeLR):
    """Multiply the learning rate of each parameter group by the factor given
    in the specified function. When last_epoch=-1, sets initial lr as lr.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param lr_lambda: A function which computes a multiplicative
        factor given an integer parameter epoch, or a list of such
        functions, one for each group in optimizer.param_groups, defaults to None
    :type lr_lambda: function or list, optional
    :param num_steps_per_epoch: number of steps per epoch, defaults to -1
    :type num_steps_per_epoch: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1,
                 last_epoch: int = -1) -> None:
        def func(step): return lr_lambda(step // num_steps_per_epoch)

        super().__init__(optimizer, func, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class StepLR(_StepLR):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler. When
    last_epoch=-1, sets initial lr as lr.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param step_size: Period of learning rate decay, defaults to 1
    :type step_size: int, optional
    :param gamma: Multiplicative factor of learning rate decay, defaults to 0.1
    :type gamma: float, optional
    :param num_steps_per_epoch: number of steps per epoch, defaults to -1
    :type num_steps_per_epoch: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1,
                 last_epoch: int = -1) -> None:
        super().__init__(optimizer, step_size * num_steps_per_epoch,
                         gamma=gamma, last_epoch=last_epoch)


@LR_SCHEDULERS.register_module
class ExponentialLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    :param optimizer: Wrapped optimizer
    :type optimizer: torch.optim.Optimizer
    :param total_steps: number of total training steps
    :type total_steps: int
    :param gamma: Multiplicative factor of learning rate decay, defaults to 1.0
    :type gamma: float, optional
    :param num_steps_per_epoch: number of steps per epoch, defaults to -1
    :type num_steps_per_epoch: int, optional
    :param last_epoch: The index of last epoch, defaults to -1
    :type last_epoch: int, optional
    """

    def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1,
                 last_epoch: int = -1) -> None:
        self.gamma = gamma
        self.num_steps_per_epoch = num_steps_per_epoch
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lrs
        elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
            return [group['lr'] * self.gamma
                    for group in self.optimizer.param_groups]
        return [group['lr']
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
                for base_lr in self.base_lrs]
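These wrappers adapt the epoch-based PyTorch schedulers to the per-step convention used elsewhere in this package; a hedged sketch with placeholder values (not part of this commit) shows how num_steps_per_epoch enters, with LambdaLR's lr_lambda still receiving an epoch index:

# Illustrative sketch -- assumes this ColossalAI package is installed.
import torch
from colossalai.nn.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 2)                              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Halve the LR every epoch; the wrapper maps step -> epoch via step // num_steps_per_epoch.
scheduler = LambdaLR(optimizer, total_steps=1000,
                     lr_lambda=lambda epoch: 0.5 ** epoch,
                     num_steps_per_epoch=100)

for _ in range(300):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])   # 0.1 * 0.5 ** 3 after three epochs' worth of steps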