Layer integration (#83)

* integrated parallel layers for easier model building (see the usage sketch after this list)

* integrated 2.5D layers

* cleaned up code and unit tests

* added a log-metric-by-step hook; updated the ImageNet benchmark; fixed some bugs

* reworked initialization; cleaned up code
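As a rough illustration of the integrated layer API, a model can be written once against `colossalai.nn` and run under whichever tensor-parallel mode the parallel context is configured with. This is a minimal sketch under that assumption, not a verbatim example from the repo:

```python
# Hedged sketch: assumes the unified colossalai.nn.Linear consolidated by this
# PR dispatches to the 1D/2D/2.5D/3D tensor-parallel implementation selected
# by the configured parallel context.
import torch.nn as nn
import colossalai.nn as col_nn


class MLP(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        # The same module definition works for every tensor-parallel mode;
        # the parallel context decides which layer implementation is built.
        self.dense_1 = col_nn.Linear(dim, dim * 4)
        self.activation = nn.GELU()
        self.dense_2 = col_nn.Linear(dim * 4, dim)

    def forward(self, x):
        return self.dense_2(self.activation(self.dense_1(x)))
```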

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Author: アマデウス
Date: 2021-12-27 15:04:32 +08:00
Committed by: GitHub
Parent: 5c3843dc98
Commit: 0fedef4f3c
118 changed files with 4941 additions and 8116 deletions


@@ -1,9 +1,7 @@
-from colossalai.registry import HOOKS
 from torch import Tensor
-from colossalai.builder import build_lr_scheduler
+from colossalai.registry import HOOKS
 
-from ._metric_hook import MetricHook
-from ..metric import LearningRate
+from ._metric_hook import LearningRateMetric, MetricHook
 
 
 @HOOKS.register_module
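The old `LearningRate` metric from `..metric` is replaced by `LearningRateMetric` from `._metric_hook`. Inferring only from how the hook constructs and updates it in the hunk below, the metric interface presumably looks roughly like this; the body is an assumption, not the actual `_metric_hook.py` implementation:

```python
# Assumed shape of LearningRateMetric, inferred from the hook's calls;
# not the real _metric_hook.py code.
class LearningRateMetric:
    def __init__(self, epoch_only: bool, initial_lr: float = 0.):
        self.epoch_only = epoch_only  # report per epoch rather than per step
        self.lr = initial_lr          # latest learning rate seen

    def update(self, lr):
        # record the newest value from scheduler.get_last_lr()[0]
        self.lr = lr
```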
@@ -19,28 +17,28 @@ class LRSchedulerHook(MetricHook):
     :param priority: Priority in the printing, hooks with small priority will be printed in front
     :type priority: int, optional
     """
-    def __init__(self,
-                 lr_scheduler,
-                 by_epoch: bool,
-                 store_lr_in_state: bool = True,
-                 priority: int = 1,
-                 ):
+    def __init__(
+        self,
+        lr_scheduler,
+        by_epoch: bool,
+        store_lr_in_state: bool = True,
+        priority: int = 1,
+    ):
         super().__init__(priority=priority)
         self.by_epoch = by_epoch
         self.lr_scheduler = lr_scheduler
         self.store_lr_in_state = store_lr_in_state
 
     def after_hook_is_attached(self, trainer):
-        trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=self.by_epoch,
-                                                                initial_lr=self.lr_scheduler.get_last_lr()[0])
+        trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
+                                                                      initial_lr=self.lr_scheduler.get_last_lr()[0])
 
     def after_train_epoch(self, trainer):
         if self.by_epoch:
             self.lr_scheduler.step()
-        trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
+        trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
 
     def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
         if not self.by_epoch:
             self.lr_scheduler.step()
-        trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
+        trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
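For context, the reworked hook would be attached to a trainer roughly as follows. This is a hedged sketch assuming the `Trainer` API of this era of colossalai; `engine`, `lr_scheduler`, and `train_loader` are placeholders built elsewhere (e.g. via `colossalai.initialize`):

```python
# Hedged usage sketch for the reworked LRSchedulerHook; the Trainer.fit
# keyword names are assumptions based on this era of the colossalai trainer.
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import LRSchedulerHook

trainer = Trainer(engine=engine)
hooks = [
    # by_epoch=False steps the scheduler every iteration, so the 'LR'
    # metric registered in after_hook_is_attached updates per step;
    # by_epoch=True steps it once per epoch instead.
    LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
]
trainer.fit(train_dataloader=train_loader, epochs=10, hooks=hooks)
```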