[pipeline] refactor pipeline (#679)

* refactor pipeline---put runtime schedule into engine.

* add type hint for schedule Optional[BaseSchedule]

* preprocess schedule during engine initializing

* infer pipeline schedule params from config
This commit is contained in:
YuliangLiu0306
2022-04-07 15:54:14 +08:00
committed by GitHub
parent eace69387d
commit 0ed7042f42
4 changed files with 36 additions and 9 deletions

View File

@@ -312,7 +312,7 @@ class AccuracyHook(MetricHook):
def after_test_iter(self, trainer, logits, targets, *args):
if self._is_stage_to_compute:
batch_size = trainer.schedule.batch_size
batch_size = trainer.engine.schedule.batch_size
self.metric.update(logits, targets, batch_size)
@@ -392,7 +392,7 @@ class ThroughputHook(MetricHook):
def after_train_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
def before_test(self, trainer):
if self._is_stage_to_compute:
@@ -400,4 +400,4 @@ class ThroughputHook(MetricHook):
def after_test_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())