Fixed docstring in colossalai (#171)

HELSON
2022-01-21 10:44:30 +08:00
committed by GitHub
parent e2089c5c15
commit 0f8c7f9804
77 changed files with 983 additions and 603 deletions

View File

@@ -12,7 +12,6 @@ class BaseHook(ABC):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
:param trainer: Trainer attached with current hook
"""
def __init__(self, priority: int) -> None:
@@ -41,6 +40,8 @@ class BaseHook(ABC):
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param output: Output of the model
:type output: torch.Tensor
:param label: Labels of the input data
@@ -88,6 +89,8 @@ class BaseHook(ABC):
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param output: Output of the model
:type output: Tensor
:param label: Labels of the input data
@@ -100,6 +103,8 @@ class BaseHook(ABC):
def init_runner_states(self, trainer, key, val):
"""Initializes trainer's state.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param key: Key of the state to reset
:param val: Value of the state to reset
"""

View File

@@ -24,7 +24,6 @@ class SaveCheckpointHook(BaseHook):
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
@@ -84,7 +83,6 @@ class LoadCheckpointHook(BaseHook):
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
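
The two checkpoint hooks differ mainly in their default priorities: 0 for loading and 10 for saving, so a load hook sorts ahead of a save hook when both are attached. A hedged usage sketch, assuming the documented parameter names map directly onto keyword arguments:

```python
from colossalai.trainer.hooks import LoadCheckpointHook, SaveCheckpointHook  # paths assumed

hooks = [
    LoadCheckpointHook(priority=0),                 # documented default: 0
    SaveCheckpointHook(priority=10, suffix='_v1'),  # suffix is optional per the docstring
]
# How the hooks are attached (e.g. trainer.fit(..., hooks=hooks)) depends on the Trainer API.
```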

View File

@@ -25,15 +25,15 @@ def _format_number(val, prec=5):
class LogByEpochHook(BaseHook):
"""hook to log by epoch
"""Hook to log by epoch
:param logger: logger for the log
:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
logger,
interval: int = 1,
@@ -48,12 +48,12 @@ class LogByEpochHook(BaseHook):
@HOOKS.register_module
class LogMetricByStepHook(BaseHook):
"""hook to log metric by step
"""Hook to log metric by step
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
@@ -62,7 +62,7 @@ class LogMetricByStepHook(BaseHook):
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'

def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
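
The hunk header above shows the formatting helper's signature, `_format_number(val, prec=5)`, but not its body. A plausible reimplementation consistent with how it is used in these f-strings — a guess at the behaviour, not the module's actual code:

```python
def _format_number(val, prec=5):
    # Round floats to `prec` significant digits so step metrics stay
    # compact in log lines; pass everything else through unchanged.
    if isinstance(val, float):
        return f'{val:.{prec}g}'
    return val
```
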
@@ -72,15 +72,13 @@ class LogMetricByStepHook(BaseHook):
@HOOKS.register_module
class LogMetricByEpochHook(LogByEpochHook):
"""Specialized Hook to record the metric to log.
"""Specialized hook to record the metric to log.
:param logger: logger for the log
:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:param mode: Mode of metrics, 'train' and 'test'
"""
def __init__(self,
@@ -116,19 +114,16 @@ class LogMetricByEpochHook(LogByEpochHook):
@HOOKS.register_module
class TensorboardHook(BaseHook):
"""Specialized Hook to record the metric to Tensorboard.
"""Specialized hook to record the metric to Tensorboard.
:param log_dir: Directory of log
:type log_dir: str
:param ranks: ranks of processors
:param ranks: Ranks of processors
:type ranks: typing.List
:param parallel_mode: Parallel mode, defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:param mode: Mode of metrics, 'train' and 'test'
:type mode: str
"""
def __init__(self,
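
Per the corrected docstring, TensorboardHook takes a log directory, a list of writer ranks, and a parallel mode defaulting to ParallelMode.GLOBAL. A hedged construction sketch, assuming the documented parameters are accepted as keywords:

```python
from colossalai.context.parallel_mode import ParallelMode
from colossalai.trainer.hooks import TensorboardHook  # import path is an assumption

tb_hook = TensorboardHook(
    log_dir='./tb_logs',                # where event files are written
    ranks=[0],                          # write from rank 0 only, avoiding duplicate curves
    parallel_mode=ParallelMode.GLOBAL,  # the documented default
    priority=10,
)
```
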
@@ -203,12 +198,12 @@ class TensorboardHook(BaseHook):
@HOOKS.register_module
class LogTimingByEpochHook(LogByEpochHook):
"""Specialized Hook to write timing record to log.
"""Specialized hook to write timing record to log.
:param timer: Timer for the hook
:type timer: colossalai.utils.MultiTimer
:type timer: :class:`colossalai.utils.MultiTimer`
:param logger: Logger for the log
:type logger: colossalai.logging.DistributedLogger
:type logger: :class:`colossalai.logging.DistributedLogger`
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
@@ -217,9 +212,8 @@ class LogTimingByEpochHook(LogByEpochHook):
:type log_eval: bool, optional
:param ignore_num_train_steps: Number of training steps to ignore, defaults to 0
:type ignore_num_train_steps: int, optional
:param mode: Mode of metrics, 'train' and 'test'
:param trainer: Trainer attached with current hook
"""
def __init__(self,
timer: MultiTimer,
logger: DistributedLogger,
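
LogTimingByEpochHook wires together the two documented collaborators, colossalai.utils.MultiTimer and colossalai.logging.DistributedLogger. A construction sketch; using `get_dist_logger` as the logger factory and sharing the timer with the trainer are assumptions on top of what the docstring states:

```python
from colossalai.logging import get_dist_logger
from colossalai.trainer.hooks import LogTimingByEpochHook  # import path is an assumption
from colossalai.utils import MultiTimer

timer = MultiTimer()  # presumably the same timer instance the trainer drives

timing_hook = LogTimingByEpochHook(
    timer=timer,
    logger=get_dist_logger(),
    interval=1,                # log every epoch (documented default)
    log_eval=True,             # also report evaluation timings
    ignore_num_train_steps=1,  # skip warm-up steps when averaging
)
```
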
@@ -285,12 +279,13 @@ class LogMemoryByEpochHook(LogByEpochHook):
:param log_eval: Whether writes in evaluation, defaults to True
:type log_eval: bool, optional
"""
def __init__(self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False,  # no reference
) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval

View File

@@ -15,7 +15,6 @@ class LRSchedulerHook(MetricHook):
:type store_lr_in_state: bool, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(
self,

View File

@@ -124,6 +124,7 @@ class LossMetric(Metric):
"""
return self.last_step_loss

@staticmethod
def is_better(a, b):
return a < b
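
Each metric carries its own ordering through the static `is_better(a, b)`: `a < b` for loss, `a > b` for accuracy further down, and a bare `pass` (i.e. None) for metrics with no natural ordering. A small illustrative consumer of that contract (not part of colossalai):

```python
def keep_better(metric_cls, candidate, best):
    """Illustrative helper: retain whichever value the metric prefers."""
    if best is None:
        return candidate
    # Metrics without an ordering (learning rate, throughput) return None,
    # in which case we simply keep the existing value.
    return candidate if metric_cls.is_better(candidate, best) else best

# With LossMetric, is_better(a, b) is `a < b`, so a lower loss replaces `best`.
```
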
@@ -133,7 +134,7 @@ class LearningRateMetric(Metric):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
:param initial_lr: initial learning rate, defaults to 0.0
:param initial_lr: Initial learning rate, defaults to 0.0
:type initial_lr: float, optional
"""
@@ -153,6 +154,7 @@ class LearningRateMetric(Metric):
def get_accumulated_value(self):
return self.lr

@staticmethod
def is_better(a, b) -> bool:
pass
@@ -163,8 +165,8 @@ class AccuracyMetric(Metric):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
:param accuracy_func: accuracy function for the classification task
:type accuracy_func: typing.Callable
:param accuracy_func: Accuracy function for the classification task
:type accuracy_func: :class:`typing.Callable`
"""
def __init__(self, epoch_only: bool, accuracy_func: Callable):
@@ -186,8 +188,8 @@ class AccuracyMetric(Metric):
and labels. It expects the output to contain logits and labels.
:param logits: The logits output of the model
:param targets: real labels of the dataset
:param batch_size: batch size of the task
:param targets: Real labels of the dataset
:param batch_size: Batch size of the task
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
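
The surrounding hunk shows `update` unwrapping list/tuple outputs and, at epoch end, all-reducing `accumulated_correct` over ParallelMode.DATA before dividing by `accumulated_sum`. A hypothetical reconstruction of the per-step accumulation; the argmax-based prediction rule is an assumption:

```python
import torch

# Sketch of what the body of AccuracyMetric.update might look like.
def update(self, logits, targets, batch_size):
    if isinstance(logits, (list, tuple)):
        logits = logits[0]                # only the first output is scored
    preds = torch.argmax(logits, dim=-1)  # assumed prediction rule
    self.accumulated_correct += (preds == targets).sum()
    self.accumulated_sum += batch_size    # denominator used at epoch end
```
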
@@ -211,6 +213,7 @@ class AccuracyMetric(Metric):
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
return (self.accumulated_correct / self.accumulated_sum).item()

@staticmethod
def is_better(a, b) -> bool:
return a > b
@@ -223,8 +226,6 @@ class MetricHook(BaseHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(
@@ -245,8 +246,6 @@ class LossHook(MetricHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, priority: int = 0):
@@ -288,8 +287,6 @@ class AccuracyHook(MetricHook):
:type accuracy_func: typing.Callable
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, accuracy_func: Callable, priority: int = 0):
@@ -319,8 +316,6 @@ class ThroughputMetric(Metric):
:param epoch_only: Whether the metric is only read for the full epoch
:type epoch_only: bool
:param num_samples: number of samples
:param time: time
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
@@ -353,6 +348,7 @@ class ThroughputMetric(Metric):
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()

@staticmethod
def is_better(a, b) -> bool:
pass
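
ThroughputMetric's accumulated value is just samples divided by seconds, with the 1e-12 epsilon guarding against a zero division before any time has been recorded. A quick numeric sanity check of the formula:

```python
accumulated_num_samples = 51200  # e.g. 100 iterations at batch size 512 (illustrative numbers)
accumulated_used_time = 40.0     # seconds, after the all_reduce over ParallelMode.DATA
throughput = accumulated_num_samples / (accumulated_used_time + 1e-12)
print(f'{throughput:.0f} samples/sec')  # -> 1280 samples/sec
```
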
@@ -363,8 +359,6 @@ class ThroughputHook(MetricHook):
:param priority: Priority of the throughput hook, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
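
Every hook in this commit is decorated with `@HOOKS.register_module`, registering the class by name so it can later be looked up, typically when hooks are instantiated from a config. A minimal sketch of that registry pattern, illustrative only; colossalai's actual Registry implementation may differ:

```python
class Registry:
    """Illustrative name -> class registry mirroring the decorator usage above."""

    def __init__(self):
        self._modules = {}

    def register_module(self, cls):
        # Used as a bare decorator: @HOOKS.register_module
        self._modules[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._modules[name]


HOOKS = Registry()

@HOOKS.register_module
class MyHook:
    pass

assert HOOKS.get('MyHook') is MyHook
```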