fix layers/schedule for hybrid parallelization (#111) (#112)

This commit is contained in:
ver217
2022-01-04 20:52:31 +08:00
committed by GitHub
parent f03bcb359b
commit 7904baf6e1
6 changed files with 44 additions and 18 deletions

View File

@@ -25,6 +25,7 @@ class Metric(ABC):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
# is the metric only read for the full epoch
self._epoch_only = epoch_only
@@ -82,6 +83,7 @@ class LossMetric(Metric):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
@@ -132,6 +134,7 @@ class LearningRateMetric(Metric):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = initial_lr
@@ -159,6 +162,7 @@ class AccuracyMetric(Metric):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, accuracy_func: Callable):
super().__init__(epoch_only=epoch_only)
self.acc = accuracy_func
@@ -217,6 +221,7 @@ class MetricHook(BaseHook):
:type trainer: Trainer
:type priority: int
"""
def __init__(
self,
priority: int,
@@ -238,6 +243,7 @@ class LossHook(MetricHook):
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, priority: int = 0):
super().__init__(priority)
@@ -278,6 +284,7 @@ class AccuracyHook(MetricHook):
:type trainer: Trainer
:type priority: int
"""
def __init__(self, accuracy_func: Callable, priority: int = 0):
super().__init__(priority)
self.accuracy_func = accuracy_func
@@ -351,13 +358,17 @@ class ThroughputHook(MetricHook):
trainer.states['metrics']['test']['Throughput'] = self.metric
def before_train_epoch(self, trainer):
self.metric.reset()
if self._is_stage_to_compute:
self.metric.reset()
def after_train_iter(self, trainer, *args):
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
if self._is_stage_to_compute:
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
def before_test(self, trainer):
self.metric.reset()
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, *args):
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
if self._is_stage_to_compute:
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())