mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 18:09:06 +00:00
Fixed docstring in colossalai (#171)
@@ -12,7 +12,6 @@ class BaseHook(ABC):

     :param priority: Priority in the printing, hooks with small priority will be printed in front
     :type priority: int
-    :param trainer: Trainer attached with current hook
     """

     def __init__(self, priority: int) -> None:

@@ -41,6 +40,8 @@ class BaseHook(ABC):
     def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
         """Actions after running a training iteration.

+        :param trainer: Trainer which is using this hook
+        :type trainer: :class:`Trainer`
         :param output: Output of the model
         :type output: torch.Tensor
         :param label: Labels of the input data

@@ -88,6 +89,8 @@ class BaseHook(ABC):
     def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
         """Actions after running a testing iteration.

+        :param trainer: Trainer which is using this hook
+        :type trainer: :class:`Trainer`
         :param output: Output of the model
         :type output: Tensor
         :param label: Labels of the input data

@@ -100,6 +103,8 @@ class BaseHook(ABC):
     def init_runner_states(self, trainer, key, val):
         """Initializes trainer's state.

+        :param trainer: Trainer which is using this hook
+        :type trainer: :class:`Trainer`
         :param key: Key of resetting state
         :param val: Value of resetting state
         """

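Together these hunks fix the BaseHook contract: a hook is constructed with only a priority, and the attached trainer is instead passed into each callback such as after_train_iter and init_runner_states. A minimal sketch of a custom hook against this surface (the IterTimerHook name and the import paths are assumptions for illustration, not part of the commit):

    import time

    from colossalai.registry import HOOKS
    from colossalai.trainer.hooks import BaseHook


    @HOOKS.register_module
    class IterTimerHook(BaseHook):
        """Toy hook that records the wall time of each training iteration."""

        def __init__(self, priority: int = 10) -> None:
            super().__init__(priority)
            self._last = time.time()

        def after_train_iter(self, trainer, output, label, loss):
            now = time.time()
            # init_runner_states(trainer, key, val) seeds a value in the trainer's state
            self.init_runner_states(trainer, 'last_iter_time', now - self._last)
            self._last = now
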
@@ -24,7 +24,6 @@ class SaveCheckpointHook(BaseHook):
     :type suffix: str, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
     """

     def __init__(self,

@@ -84,7 +83,6 @@ class LoadCheckpointHook(BaseHook):
     :type suffix: str, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
     """

    def __init__(self,

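With the trainer parameter gone from both checkpoint hook docstrings, the intended usage is to build the hooks standalone and hand them to the trainer. A hedged wiring sketch (only the priority argument appears in this diff; the import path and the trainer.fit call shape are assumptions):

    from colossalai.trainer.hooks import LoadCheckpointHook, SaveCheckpointHook

    # Hooks are constructed without a trainer; priorities order their printing.
    hook_list = [
        LoadCheckpointHook(priority=0),
        SaveCheckpointHook(priority=10),
    ]

    # Assumed call shape, not shown in this diff:
    # trainer.fit(train_dataloader=train_loader, epochs=num_epochs, hooks=hook_list)
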
@@ -25,15 +25,15 @@ def _format_number(val, prec=5):

 class LogByEpochHook(BaseHook):
-    """hook to log by epoch
+    """Hook to log by epoch

-    :param logger: logger for the log
+    :param logger: Logger for the log
     :param interval: Recording interval, defaults to 1
     :type interval: int, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
     :type priority: int, optional
     :param trainer: Trainer attached with current hook
     """

     def __init__(self,
                  logger,
                  interval: int = 1,

@@ -48,12 +48,12 @@ class LogByEpochHook(BaseHook):

 @HOOKS.register_module
 class LogMetricByStepHook(BaseHook):
-    """hook to log metric by step
+    """Hook to log metric by step

     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
     :type priority: int, optional
     :param trainer: Trainer attached with current hook
     """

     def __init__(self, priority: int = 10):
         super().__init__(priority)

@@ -62,7 +62,7 @@ class LogMetricByStepHook(BaseHook):
         for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
             trainer.states['step_metrics'][metric_name.lower()] = \
                 f'{_format_number(metric_calculator.get_last_step_value())}'

     def after_test_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['test'].items():

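Both callbacks above do the same bookkeeping: each metric's last step value is formatted by _format_number and stored in trainer.states['step_metrics'] under a lowercased key. The same logic in isolation, with an illustrative stand-in for the metric calculator (the _format_number body here mirrors the helper's intent and is an assumption, not its exact source):

    def _format_number(val, prec=5):
        # Fixed-precision display for floats, pass anything else through.
        if isinstance(val, float):
            return f'{val:.{prec}g}'
        return val


    class DummyMetric:
        """Illustrative stand-in for a metric calculator in trainer.states."""

        def get_last_step_value(self):
            return 0.123456789


    states = {'metrics': {'train': {'Loss': DummyMetric()}}, 'step_metrics': {}}
    for metric_name, metric_calculator in states['metrics']['train'].items():
        states['step_metrics'][metric_name.lower()] = \
            f'{_format_number(metric_calculator.get_last_step_value())}'

    print(states['step_metrics'])  # {'loss': '0.12346'}
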
@@ -72,15 +72,13 @@ class LogMetricByStepHook(BaseHook):

 @HOOKS.register_module
 class LogMetricByEpochHook(LogByEpochHook):
-    """Specialized Hook to record the metric to log.
+    """Specialized hook to record the metric to log.

-    :param logger: logger for the log
+    :param logger: Logger for the log
     :param interval: Recording interval, defaults to 1
     :type interval: int, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
-    :param mode: Mode of metrics, 'train' and 'test'
     """

     def __init__(self,

@@ -116,19 +114,16 @@ class LogMetricByEpochHook(LogByEpochHook):

 @HOOKS.register_module
 class TensorboardHook(BaseHook):
-    """Specialized Hook to record the metric to Tensorboard.
+    """Specialized hook to record the metric to Tensorboard.

     :param log_dir: Directory of log
     :type log_dir: str
-    :param ranks: ranks of processors
+    :param ranks: Ranks of processors
     :type ranks: typing.List
     :param parallel_mode: Parallel mode, defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL
-    :type parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
+    :type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
-    :param mode: Mode of metrics, 'train' and 'test'
-    :type mode: str
     """

     def __init__(self,

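A recurring fix across these hunks is wrapping bare dotted type names in Sphinx cross-reference roles. A bare path renders as inert text, while the :class: role makes Sphinx emit a hyperlink to the documented class, as the corrected TensorboardHook field reads after this commit:

    :param parallel_mode: Parallel mode, defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL
    :type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional

The same role syntax is applied to :class:`typing.Callable`, :class:`colossalai.utils.MultiTimer`, and :class:`colossalai.logging.DistributedLogger` in the hunks below.
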
@@ -203,12 +198,12 @@ class TensorboardHook(BaseHook):

 @HOOKS.register_module
 class LogTimingByEpochHook(LogByEpochHook):
-    """Specialized Hook to write timing record to log.
+    """Specialized hook to write timing record to log.

     :param timer: Timer for the hook
-    :type timer: colossalai.utils.MultiTimer
+    :type timer: :class:`colossalai.utils.MultiTimer`
     :param logger: Logger for the log
-    :type logger: colossalai.logging.DistributedLogger
+    :type logger: :class:`colossalai.logging.DistributedLogger`
     :param interval: Recording interval, defaults to 1
     :type interval: int, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10

@@ -217,9 +212,8 @@ class LogTimingByEpochHook(LogByEpochHook):
     :type log_eval: bool, optional
     :param ignore_num_train_steps: Number of training steps to ignore, defaults to 0
     :type ignore_num_train_steps: int, optional
     :param mode: Mode of metrics, 'train' and 'test'
-    :param trainer: Trainer attached with current hook
     """
     def __init__(self,
                  timer: MultiTimer,
                  logger: DistributedLogger,

@@ -285,12 +279,13 @@ class LogMemoryByEpochHook(LogByEpochHook):
     :param log_eval: Whether writes in evaluation, defaults to True
     :type log_eval: bool, optional
     """

     def __init__(self,
                  logger: DistributedLogger,
                  interval: int = 1,
                  priority: int = 10,
                  log_eval: bool = True,
+                 report_cpu: bool = False,  # no reference
                  ) -> None:
         super().__init__(logger=logger, interval=interval, priority=priority)
         self._log_eval = log_eval

@@ -15,7 +15,6 @@ class LRSchedulerHook(MetricHook):
     :type store_lr_in_state: bool, optional
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
     """
     def __init__(
         self,

@@ -124,6 +124,7 @@ class LossMetric(Metric):
         """
         return self.last_step_loss

+    @staticmethod
     def is_better(a, b):
         return a < b

@@ -133,7 +134,7 @@ class LearningRateMetric(Metric):

     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
-    :param initial_lr: initial learning rate, defaults to 0.0
+    :param initial_lr: Initial learning rate, defaults to 0.0
     :type initial_lr: float, optional
     """

@@ -153,6 +154,7 @@ class LearningRateMetric(Metric):
     def get_accumulated_value(self):
         return self.lr

+    @staticmethod
     def is_better(a, b) -> bool:
         pass

@@ -163,8 +165,8 @@ class AccuracyMetric(Metric):

     :param epoch_only: Whether the metric only read for the full epoch
     :type epoch_only: bool
-    :param accuracy_func: accuracy function for the classification task
-    :type accuracy_func: typing.Callable
+    :param accuracy_func: Accuracy function for the classification task
+    :type accuracy_func: :class:`typing.Callable`
     """

     def __init__(self, epoch_only: bool, accuracy_func: Callable):

@@ -186,8 +188,8 @@ class AccuracyMetric(Metric):
         and labels. It expects the output has logits and labels.

         :param logits: The logits output of the model
-        :param targets: real labels of the dataset
-        :param batch_size: batch size of the task
+        :param targets: Real labels of the dataset
+        :param batch_size: Batch size of the task
         """
         if isinstance(logits, (list, tuple)):
             logits = logits[0]

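The update path above first unwraps list or tuple model outputs down to their logits tensor before comparing against targets. An accuracy_func matching this docstring takes logits and real labels and counts correct predictions; a hedged PyTorch sketch (the exact return contract AccuracyMetric expects is not shown in this diff):

    import torch

    def accuracy_func(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # argmax over the class dimension gives the predicted label per sample
        preds = logits.argmax(dim=-1)
        # count of correct predictions; divide by batch size for a fraction
        return (preds == targets).sum()

    logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, -1.0]])
    targets = torch.tensor([0, 1, 1])
    print(accuracy_func(logits, targets))  # tensor(2)
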
@@ -211,6 +213,7 @@ class AccuracyMetric(Metric):
         self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
         return (self.accumulated_correct / self.accumulated_sum).item()

+    @staticmethod
     def is_better(a, b) -> bool:
         return a > b

@@ -223,8 +226,6 @@ class MetricHook(BaseHook):

     :param priority: Priority in the printing, hooks with small priority will be printed in front
     :type priority: int
-    :param trainer: Trainer attached with current hook
-    :type trainer: Trainer
     """

     def __init__(

@@ -245,8 +246,6 @@ class LossHook(MetricHook):

     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
-    :type trainer: Trainer
     """

     def __init__(self, priority: int = 0):

@@ -288,8 +287,6 @@ class AccuracyHook(MetricHook):
     :type accuracy_func: typing.Callable
     :param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
-    :type trainer: Trainer
     """

     def __init__(self, accuracy_func: Callable, priority: int = 0):

@@ -319,8 +316,6 @@ class ThroughputMetric(Metric):

     :param epoch_only: epoch only
     :type epoch_only: bool
-    :param num_samples: number of samples
-    :param time: time
     """
     def __init__(self, epoch_only: bool):
         super().__init__(epoch_only=epoch_only)

@@ -353,6 +348,7 @@ class ThroughputMetric(Metric):
         self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
         return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()

+    @staticmethod
     def is_better(a, b) -> bool:
         pass

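The accumulated value above is plain samples-per-second arithmetic: the all-reduced sample count divided by the accumulated step time, with a 1e-12 epsilon so that an epoch with no recorded time cannot divide by zero. The same computation in isolation, with illustrative numbers:

    # Accumulated over an epoch (illustrative numbers)
    accumulated_num_samples = 51200.0  # samples processed across the data-parallel group
    accumulated_used_time = 12.8       # seconds of measured step time

    # The epsilon keeps the division finite even if no time was recorded.
    throughput = accumulated_num_samples / (accumulated_used_time + 1e-12)
    print(f'{throughput:.1f} samples/sec')  # 4000.0 samples/sec
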
@@ -363,8 +359,6 @@ class ThroughputHook(MetricHook):

     :param priority: priority of throughput hook, defaults to 10
     :type priority: int, optional
-    :param trainer: Trainer attached with current hook
-    :type trainer: Trainer
     """
     def __init__(self, priority: int = 10):
         super().__init__(priority)