mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
Update layer integration documentations (#108)
Update the documentations of layer integration Update _log_hook.py Update _operation.py
This commit is contained in:
@@ -133,6 +133,8 @@ class LearningRateMetric(Metric):
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
:param initial_lr: initial learning rate, defaults to 0.0
|
||||
:type initial_lr: float, optional
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||
@@ -161,6 +163,8 @@ class AccuracyMetric(Metric):
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
:param accuracy_func: accuracy function for the classification task
|
||||
:type accuracy_func: typing.Callable
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
@@ -182,7 +186,8 @@ class AccuracyMetric(Metric):
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
:param logits: The logits output of the model
|
||||
:param label: The labels of the input data
|
||||
:param targets: real labels of the dataset
|
||||
:param batch_size: batch size of the task
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
@@ -216,10 +221,10 @@ class MetricHook(BaseHook):
|
||||
update their states. Others are used to display and
|
||||
record the metric.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -238,10 +243,10 @@ class MetricHook(BaseHook):
|
||||
class LossHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Loss`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
|
||||
:type priority: int, optional
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
@@ -279,10 +284,12 @@ class LossHook(MetricHook):
|
||||
class AccuracyHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy`.
|
||||
|
||||
:param accuracy_func: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type accuracy_func: typing.Callable
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
|
||||
:type priority: int, optional
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, accuracy_func: Callable, priority: int = 0):
|
||||
@@ -308,6 +315,13 @@ class AccuracyHook(MetricHook):
|
||||
|
||||
|
||||
class ThroughputMetric(Metric):
|
||||
"""Metric for :class:`Throughput`.
|
||||
|
||||
:param epoch_only: epoch only
|
||||
:type epoch_only: bool
|
||||
:param num_samples: number of samples
|
||||
:param time: time
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
|
||||
@@ -345,6 +359,13 @@ class ThroughputMetric(Metric):
|
||||
|
||||
@HOOKS.register_module
|
||||
class ThroughputHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Throughput`.
|
||||
|
||||
:param priority: priority of throughput hook, defaults to 10
|
||||
:type priority: int, optional
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
"""
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
|
Reference in New Issue
Block a user