mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[pipeline] refactor pipeline (#679)
* refactor pipeline---put runtime schedule into engine. * add type hint for schedule Optional[BaseSchedule] * preprocess schedule during engine initializing * infer pipeline schedule params from config
This commit is contained in:
@@ -312,7 +312,7 @@ class AccuracyHook(MetricHook):
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
if self._is_stage_to_compute:
|
||||
batch_size = trainer.schedule.batch_size
|
||||
batch_size = trainer.engine.schedule.batch_size
|
||||
self.metric.update(logits, targets, batch_size)
|
||||
|
||||
|
||||
@@ -392,7 +392,7 @@ class ThroughputHook(MetricHook):
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
@@ -400,4 +400,4 @@ class ThroughputHook(MetricHook):
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
Reference in New Issue
Block a user