mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
@@ -25,6 +25,7 @@ class Metric(ABC):
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool):
|
||||
# is the metric only read for the full epoch
|
||||
self._epoch_only = epoch_only
|
||||
@@ -82,6 +83,7 @@ class LossMetric(Metric):
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
@@ -132,6 +134,7 @@ class LearningRateMetric(Metric):
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.lr = initial_lr
|
||||
@@ -159,6 +162,7 @@ class AccuracyMetric(Metric):
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.acc = accuracy_func
|
||||
@@ -217,6 +221,7 @@ class MetricHook(BaseHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
priority: int,
|
||||
@@ -238,6 +243,7 @@ class LossHook(MetricHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
@@ -278,6 +284,7 @@ class AccuracyHook(MetricHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, accuracy_func: Callable, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
self.accuracy_func = accuracy_func
|
||||
@@ -351,13 +358,17 @@ class ThroughputHook(MetricHook):
|
||||
trainer.states['metrics']['test']['Throughput'] = self.metric
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
self.metric.reset()
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
self.metric.reset()
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
Reference in New Issue
Block a user