mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[legacy] move trainer to legacy (#4545)
* [legacy] move trainer to legacy * [doc] update docs related to trainer * [test] ignore legacy test
This commit is contained in:
0
colossalai/legacy/__init__.py
Normal file
0
colossalai/legacy/__init__.py
Normal file
@@ -1,14 +1,13 @@
|
||||
from typing import Union, List, Any
|
||||
from typing import Any, List, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.legacy.trainer.hooks import BaseHook
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.utils import MultiTimer
|
||||
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
|
||||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
|
||||
|
||||
|
||||
class Trainer:
|
@@ -1,7 +1,12 @@
|
||||
from ._base_hook import BaseHook
|
||||
from ._checkpoint_hook import SaveCheckpointHook
|
||||
from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
|
||||
TensorboardHook)
|
||||
from ._log_hook import (
|
||||
LogMemoryByEpochHook,
|
||||
LogMetricByEpochHook,
|
||||
LogMetricByStepHook,
|
||||
LogTimingByEpochHook,
|
||||
TensorboardHook,
|
||||
)
|
||||
from ._lr_scheduler_hook import LRSchedulerHook
|
||||
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
|
||||
|
@@ -1,11 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import torch
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from colossalai.legacy.trainer.hooks import BaseHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.utils.checkpointing import save_checkpoint
|
||||
|
||||
from ._lr_scheduler_hook import LRSchedulerHook
|
||||
|
||||
|
@@ -3,17 +3,17 @@
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
from typing import List
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.utils import report_memory_usage, is_dp_rank_0, \
|
||||
is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ._commons_ import _format_number
|
||||
from colossalai.trainer.hooks._metric_hook import ThroughputMetric
|
||||
|
||||
|
||||
class LogByEpochHook(BaseHook):
|
@@ -1,6 +1,7 @@
|
||||
from colossalai.registry import HOOKS
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import HOOKS
|
||||
|
||||
from ._metric_hook import LearningRateMetric, MetricHook
|
||||
|
||||
|
@@ -6,6 +6,7 @@ from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.communication import all_reduce
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
@@ -19,8 +20,8 @@ from ._commons_ import _format_number
|
||||
class Metric(ABC):
|
||||
"""A basic class of metric collectors. It collects a specific
|
||||
metric during training or evaluation and would always be used with
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
collector works.
|
||||
|
||||
Args:
|
||||
@@ -220,9 +221,9 @@ class AccuracyMetric(Metric):
|
||||
|
||||
|
||||
class MetricHook(BaseHook):
|
||||
"""Specialized hook classes for :class:`Metric`.
|
||||
Some help metric collectors initialize, reset and
|
||||
update their states. Others are used to display and
|
||||
"""Specialized hook classes for :class:`Metric`.
|
||||
Some help metric collectors initialize, reset and
|
||||
update their states. Others are used to display and
|
||||
record the metric.
|
||||
|
||||
Args:
|
Reference in New Issue
Block a user