mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
Fixed docstring in colossalai (#171)
This commit is contained in:
@@ -124,6 +124,7 @@ class LossMetric(Metric):
|
||||
"""
|
||||
return self.last_step_loss
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b):
|
||||
return a < b
|
||||
|
||||
@@ -133,7 +134,7 @@ class LearningRateMetric(Metric):
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
:param initial_lr: initial learning rate, defaults to 0.0
|
||||
:param initial_lr: Initial learning rate, defaults to 0.0
|
||||
:type initial_lr: float, optional
|
||||
"""
|
||||
|
||||
@@ -153,6 +154,7 @@ class LearningRateMetric(Metric):
|
||||
def get_accumulated_value(self):
|
||||
return self.lr
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
@@ -163,8 +165,8 @@ class AccuracyMetric(Metric):
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
:param accuracy_func: accuracy function for the classification task
|
||||
:type accuracy_func: typing.Callable
|
||||
:param accuracy_func: Accuracy function for the classification task
|
||||
:type accuracy_func: :class:`typing.Callable`
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
@@ -186,8 +188,8 @@ class AccuracyMetric(Metric):
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
:param logits: The logits output of the model
|
||||
:param targets: real labels of the dataset
|
||||
:param batch_size: batch size of the task
|
||||
:param targets: Real labels of the dataset
|
||||
:param batch_size: Batch size of the task
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
@@ -211,6 +213,7 @@ class AccuracyMetric(Metric):
|
||||
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
|
||||
return (self.accumulated_correct / self.accumulated_sum).item()
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
@@ -223,8 +226,6 @@ class MetricHook(BaseHook):
|
||||
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -245,8 +246,6 @@ class LossHook(MetricHook):
|
||||
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
|
||||
:type priority: int, optional
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
@@ -288,8 +287,6 @@ class AccuracyHook(MetricHook):
|
||||
:type accuracy_func: typing.Callable
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
|
||||
:type priority: int, optional
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
|
||||
def __init__(self, accuracy_func: Callable, priority: int = 0):
|
||||
@@ -319,8 +316,6 @@ class ThroughputMetric(Metric):
|
||||
|
||||
:param epoch_only: epoch only
|
||||
:type epoch_only: bool
|
||||
:param num_samples: number of samples
|
||||
:param time: time
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
@@ -353,6 +348,7 @@ class ThroughputMetric(Metric):
|
||||
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
|
||||
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
@@ -363,8 +359,6 @@ class ThroughputHook(MetricHook):
|
||||
|
||||
:param priority: priority of throughput hook, defaults to 10
|
||||
:type priority: int, optional
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
Reference in New Issue
Block a user