[legacy] move builder and registry to legacy (#4603)

This commit is contained in:
Hongxin Liu
2023-09-04 19:56:42 +08:00
parent 8accecd55b
commit ac178ca5c1
65 changed files with 353 additions and 332 deletions

View File

@@ -2,9 +2,9 @@
# -*- encoding: utf-8 -*-
import torch
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
from colossalai.utils.checkpointing import save_checkpoint
from ._lr_scheduler_hook import LRSchedulerHook

View File

@@ -7,9 +7,9 @@ from typing import List
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
from colossalai.registry import HOOKS
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
from ._base_hook import BaseHook

View File

@@ -1,6 +1,6 @@
from torch import Tensor
from colossalai.registry import HOOKS
from colossalai.legacy.registry import HOOKS
from ._metric_hook import LearningRateMetric, MetricHook

View File

@@ -10,7 +10,7 @@ import torch.distributed as dist
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.legacy.registry import HOOKS
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
from ._base_hook import BaseHook
@@ -356,7 +356,7 @@ class ThroughputMetric(Metric):
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -367,7 +367,7 @@ class ThroughputMetric(Metric):
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())