[legacy] move trainer to legacy (#4545)

* [legacy] move trainer to legacy

* [doc] update docs related to trainer

* [test] ignore legacy test
Hongxin Liu
2023-08-31 13:51:28 +08:00
parent 807e01a4ba
commit 89fe027787
32 changed files with 63 additions and 153 deletions
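In practice, the user-facing change in this commit is an import-path move. A minimal migration sketch, assuming the public names are re-exported unchanged under the new legacy package (as the hunks below suggest):

# Before this commit (old path, assumed removed by the move):
#   from colossalai.trainer import Trainer
#   from colossalai.trainer.hooks import BaseHook

# After this commit:
from colossalai.legacy.trainer import Trainer
from colossalai.legacy.trainer.hooks import BaseHook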

View File

@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union
 
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from colossalai.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
 
 
 class Trainer:
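Since this hunk touches the Trainer entry point itself, here is a rough usage sketch for call sites being migrated. It follows the trainer API as documented before the move; engine, logger, and train_dataloader are placeholders assumed to come from the usual colossalai.initialize setup:

from colossalai.legacy.trainer import Trainer, hooks
from colossalai.utils import MultiTimer

# engine, logger, train_dataloader: placeholders from colossalai.initialize(...)
timer = MultiTimer()
trainer = Trainer(engine=engine, timer=timer, logger=logger)

hook_list = [
    hooks.LossHook(),
    hooks.LogMetricByEpochHook(logger),
]

trainer.fit(
    train_dataloader=train_dataloader,
    epochs=2,
    hooks=hook_list,
    display_progress=True,
)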

View File

@@ -1,7 +1,12 @@
 from ._base_hook import BaseHook
 from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
-                        TensorboardHook)
+from ._log_hook import (
+    LogMemoryByEpochHook,
+    LogMetricByEpochHook,
+    LogMetricByStepHook,
+    LogTimingByEpochHook,
+    TensorboardHook,
+)
 from ._lr_scheduler_hook import LRSchedulerHook
 from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
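The names re-exported by this __init__ are the ones downstream code should now pull from the legacy path. A short sketch (the old colossalai.trainer.hooks path is assumed to stop working after the move):

from colossalai.legacy.trainer.hooks import (
    AccuracyHook,
    LogMetricByEpochHook,
    LossHook,
    LRSchedulerHook,
    SaveCheckpointHook,
)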

View File

@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import torch
 
-from colossalai.logging import get_dist_logger
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
 from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
 from colossalai.utils.checkpointing import save_checkpoint
 
 from ._lr_scheduler_hook import LRSchedulerHook

View File

@@ -3,17 +3,17 @@
 import os
 import os.path as osp
 from typing import List
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
 from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
-    is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.registry import HOOKS
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
+
 from ._base_hook import BaseHook
 from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric
 
 
 class LogByEpochHook(BaseHook):

View File

@@ -1,6 +1,7 @@
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.registry import HOOKS
from ._metric_hook import LearningRateMetric, MetricHook

View File

@@ -6,6 +6,7 @@ from typing import Callable
 import torch
 import torch.distributed as dist
+
 from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -19,8 +20,8 @@ from ._commons_ import _format_number
 class Metric(ABC):
     """A basic class of metric collectors. It collects a specific
     metric during training or evaluation and would always be used with
-    :class:`MetricHook` to help it update its states and show the 
-    metric. So please use corresponding hook class to make the metric 
+    :class:`MetricHook` to help it update its states and show the
+    metric. So please use corresponding hook class to make the metric
     collector works.
 
     Args:
@@ -220,9 +221,9 @@ class AccuracyMetric(Metric):
 class MetricHook(BaseHook):
-    """Specialized hook classes for :class:`Metric`. 
-    Some help metric collectors initialize, reset and 
-    update their states. Others are used to display and 
+    """Specialized hook classes for :class:`Metric`.
+    Some help metric collectors initialize, reset and
+    update their states. Others are used to display and
     record the metric.
 
     Args:
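The docstrings above describe the contract between a Metric collector and its MetricHook: the hook drives updates and resets, while the metric only stores state. A hypothetical subclass as a sketch of that contract; the method set mirrors the built-in metrics in this file, but treat the exact signatures as assumptions:

from colossalai.legacy.trainer.hooks._metric_hook import Metric

class LastLossMetric(Metric):
    # Hypothetical collector: keeps the most recent loss value.
    def __init__(self, epoch_only: bool = False):
        super().__init__(epoch_only=epoch_only)
        self._last = 0.0

    def update(self, loss):          # driven by a MetricHook each step
        self._last = float(loss)

    def reset(self):                 # driven by a MetricHook at epoch start
        self._last = 0.0

    def get_last_step_value(self):
        return str(self._last)

    def get_accumulated_value(self):
        return self._last

    @staticmethod
    def is_better(a, b) -> bool:
        return a < b                 # lower loss is better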