From 807e01a4bae5d1c49747bcb4ae69c98871bce9ff Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Tue, 5 Sep 2023 15:04:02 +0800
Subject: [PATCH 1/4] [zero] hotfix master param sync (#4618)

* [zero] add method to update master params

* [zero] update zero plugin

* [plugin] update low level zero plugin
---
 .../booster/plugin/low_level_zero_plugin.py   | 123 ++++++++++++------
 colossalai/interface/__init__.py              |   4 +-
 colossalai/interface/model.py                 |  11 ++
 colossalai/zero/low_level/low_level_optim.py  |  17 +++
 .../test_low_level_zero_checkpoint_io.py      |  12 ++
 5 files changed, 122 insertions(+), 45 deletions(-)

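Editor's note (not part of the patch): a minimal usage sketch of the flow this hotfix targets, assuming the Booster / LowLevelZeroPlugin API touched by the diffs below; the model, optimizer and checkpoint path are placeholders, and the script is meant to run under a distributed launcher such as torchrun. After `booster.load_model(...)`, the wrapped model's injected `update_master_params()` re-syncs the ZeRO fp32 master shards with the freshly loaded fp16/bf16 working weights.

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch(config={})

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    booster = Booster(plugin=LowLevelZeroPlugin(stage=1, precision='fp16'))
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Before this patch, loading a checkpoint updated only the working (fp16)
    # params; the optimizer's fp32 master shards kept their stale values.
    booster.load_model(model, 'model_ckpt.pt')    # placeholder path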
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 6efafc56d..9adb4beec 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
 
 
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
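Editor's sketch (not from the patch): how `LowLevelZeroModel.forward` casts its inputs. `tree_map` applies `convert_fn` to every leaf of the possibly nested args/kwargs, and `_convert_floating_point` is assumed here to cast only floating-point tensors, matching its name; integer/bool tensors and non-tensor leaves pass through untouched.

    from functools import partial

    import torch
    from torch.utils._pytree import tree_map


    def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
        # assumed behaviour: cast only floating-point tensors
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x


    convert_fn = partial(_convert_floating_point, dtype=torch.float16)
    args = (torch.randn(2, 4), {'mask': torch.ones(2, dtype=torch.bool)})
    args = tree_map(convert_fn, args)   # fp32 tensor -> fp16, bool mask unchanged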
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
         sharded_optimizer_loading_epilogue(optimizer)
 
+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
 
-class LowLevelZeroModel(ModelWrapper):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
 
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
 
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()
 
 
 class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
-
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose
 
         # set class name with stage, for better error message
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)
 
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
 
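Editor's note: the `MethodType` line above is the core of the fix. A minimal, self-contained illustration of that binding pattern (toy classes, not ColossalAI code): binding the optimizer's bound method onto the model instance makes `model.update_master_params()` call `optimizer.update_master_params(model)`, so the checkpoint IO never needs a handle on the optimizer.

    from types import MethodType


    class ToyOptimizer:

        def update_master_params(self, model) -> None:
            print(f'syncing master params of {type(model).__name__}')


    class ToyModel:
        pass


    optim, model = ToyOptimizer(), ToyModel()
    model.update_master_params = MethodType(optim.update_master_params, model)
    model.update_master_params()   # equivalent to optim.update_master_params(model)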
diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py
index 8c658e375..1c3199fc1 100644
--- a/colossalai/interface/__init__.py
+++ b/colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper
 
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index a067d7671..7b3d9435d 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
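Editor's sketch (hypothetical, not library code): `AMPModelMixin` only declares the hook with a no-op default. A wrapper that keeps its own fp32 master copies could override it directly, whereas the low-level ZeRO plugin supplies the method by injection from the optimizer instead of a subclass override.

    import torch.nn as nn

    from colossalai.interface import AMPModelMixin, ModelWrapper


    class ToyAMPModel(ModelWrapper, AMPModelMixin):

        def __init__(self, module: nn.Module):
            super().__init__(module.half())
            # fp32 master copy of every fp16 working param
            self.master_params = [p.detach().clone().float() for p in self.module.parameters()]

        def update_master_params(self):
            # refresh master copies from the working params, e.g. after a checkpoint load
            for master, working in zip(self.master_params, self.module.parameters()):
                master.data.copy_(working.data)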
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index b4439ab19..d9d6298d7 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
+
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
+
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        for p in model.parameters():
+            p_id = id(p)
+            if p_id in self._param_store.working_to_master_param:
+                master_param = self._param_store.working_to_master_param[p_id]
+                padding_size = self._param_store.get_param_padding_size(p)
+                working_param = p.data.view(-1)
+                if padding_size > 0:
+                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
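Editor's sketch of the layout `update_master_params` assumes: each flattened working parameter is zero-padded to a multiple of the world size, split into equal chunks, and every rank keeps only its own chunk as the fp32 master shard. `world_size` and `rank` are illustrative values; this is standalone arithmetic, not the optimizer code.

    import torch
    import torch.nn.functional as F

    world_size, rank = 4, 1
    working_param = torch.arange(10.).to(torch.float16)         # flat fp16 working param
    padding_size = (-working_param.numel()) % world_size        # 2 -> padded length 12


    def local_shard(p: torch.Tensor) -> torch.Tensor:
        padded = F.pad(p.view(-1), [0, padding_size])
        return padded.chunk(world_size)[rank]                   # this rank's slice


    master_param = local_shard(working_param).clone().float()   # fp32 master shard

    working_param += 1                                           # pretend a checkpoint load changed it
    master_param.copy_(local_shard(working_param))               # per-param body of update_master_params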
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index 3faa395b5..7ee733b26 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -14,6 +14,7 @@ from colossalai.testing import (
     rerun_if_address_is_in_use,
     spawn,
 )
+from colossalai.zero import LowLevelZeroOptimizer
 
 
 # stage 1 and 2 process the optimizer/model the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
 
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+        # check master weight
+        assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+        working_param_id_set = set(id(p) for p in new_model.parameters())
+        for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+            assert p_id in working_param_id_set
+            working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+            padding = new_optimizer._param_store.get_param_padding_size(working_param)
+            padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+            working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+            assert torch.equal(working_shard,
+                               master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
 
         booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)

From 89fe0277875146cc521f1e15e508efd43e56f34c Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Thu, 31 Aug 2023 13:51:28 +0800
Subject: [PATCH 2/4] [legacy] move trainer to legacy (#4545)

* [legacy] move trainer to legacy

* [doc] update docs related to trainer

* [test] ignore legacy test
---
 colossalai/legacy/__init__.py                 |   0
 colossalai/{ => legacy}/trainer/__init__.py   |   0
 colossalai/{ => legacy}/trainer/_trainer.py   |   7 +-
 .../{ => legacy}/trainer/hooks/__init__.py    |   9 +-
 .../{ => legacy}/trainer/hooks/_base_hook.py  |   0
 .../trainer/hooks/_checkpoint_hook.py         |   5 +-
 .../{ => legacy}/trainer/hooks/_commons_.py   |   0
 .../{ => legacy}/trainer/hooks/_log_hook.py   |  10 +-
 .../trainer/hooks/_lr_scheduler_hook.py       |   3 +-
 .../trainer/hooks/_metric_hook.py             |  11 +-
 .../train_gpt_using_hybrid_parallelism.md     |   3 +-
 .../train_vit_using_pipeline_parallelism.md   |   3 +-
 .../train_vit_with_hybrid_parallelism.md      |   3 +-
 docs/source/en/basics/engine_trainer.md       |   7 +-
 docs/source/en/basics/model_checkpoint.md     |   3 +-
 .../en/features/mixed_precision_training.md   |   2 +-
 docs/source/en/features/pipeline_parallel.md  |   3 +-
 .../train_gpt_using_hybrid_parallelism.md     |   3 +-
 .../train_vit_using_pipeline_parallelism.md   |   3 +-
 .../train_vit_with_hybrid_parallelism.md      |   3 +-
 docs/source/zh-Hans/basics/engine_trainer.md  |   7 +-
 .../source/zh-Hans/basics/model_checkpoint.md |   3 +-
 .../features/mixed_precision_training.md      |   2 +-
 .../zh-Hans/features/pipeline_parallel.md     |   3 +-
 examples/language/gpt/titans/train_gpt.py     |   2 +-
 pytest.ini                                    |   2 +-
 .../test_cifar_with_data_pipeline_tensor.py   | 100 ------------------
 .../test_trainer/test_pipeline/test_p2p.py    |   0
 .../test_pipeline/test_pipeline_schedule.py   |   0
 .../test_trainer_with_non_pipe_schedule.py    |   2 +-
 .../test_trainer_with_pipe_schedule.py        |   2 +-
 .../test_cuda_rpc_performance.py              |  15 +--
 32 files changed, 63 insertions(+), 153 deletions(-)
 create mode 100644 colossalai/legacy/__init__.py
 rename colossalai/{ => legacy}/trainer/__init__.py (100%)
 rename colossalai/{ => legacy}/trainer/_trainer.py (98%)
 rename colossalai/{ => legacy}/trainer/hooks/__init__.py (75%)
 rename colossalai/{ => legacy}/trainer/hooks/_base_hook.py (100%)
 rename colossalai/{ => legacy}/trainer/hooks/_checkpoint_hook.py (98%)
 rename colossalai/{ => legacy}/trainer/hooks/_commons_.py (100%)
 rename colossalai/{ => legacy}/trainer/hooks/_log_hook.py (98%)
 rename colossalai/{ => legacy}/trainer/hooks/_lr_scheduler_hook.py (99%)
 rename colossalai/{ => legacy}/trainer/hooks/_metric_hook.py (98%)
 delete mode 100644 tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
 rename tests/{ => test_legacy}/test_trainer/test_pipeline/test_p2p.py (100%)
 rename tests/{ => test_legacy}/test_trainer/test_pipeline/test_pipeline_schedule.py (100%)
 rename tests/{ => test_legacy}/test_trainer/test_trainer_with_non_pipe_schedule.py (97%)
 rename tests/{ => test_legacy}/test_trainer/test_trainer_with_pipe_schedule.py (98%)

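Editor's note: for downstream code this patch is an import-path move; a sketch of the migration, mirroring the doc and test updates below:

    # before this patch
    # from colossalai.trainer import Trainer, hooks

    # after this patch
    from colossalai.legacy.trainer import Trainer, hooks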
diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/trainer/__init__.py b/colossalai/legacy/trainer/__init__.py
similarity index 100%
rename from colossalai/trainer/__init__.py
rename to colossalai/legacy/trainer/__init__.py
diff --git a/colossalai/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py
similarity index 98%
rename from colossalai/trainer/_trainer.py
rename to colossalai/legacy/trainer/_trainer.py
index bfe1c403f..fb66acec5 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/legacy/trainer/_trainer.py
@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union
 
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from colossalai.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
 
 
 class Trainer:
diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/legacy/trainer/hooks/__init__.py
similarity index 75%
rename from colossalai/trainer/hooks/__init__.py
rename to colossalai/legacy/trainer/hooks/__init__.py
index 4d3609383..bf9cc6421 100644
--- a/colossalai/trainer/hooks/__init__.py
+++ b/colossalai/legacy/trainer/hooks/__init__.py
@@ -1,7 +1,12 @@
 from ._base_hook import BaseHook
 from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
-                        TensorboardHook)
+from ._log_hook import (
+    LogMemoryByEpochHook,
+    LogMetricByEpochHook,
+    LogMetricByStepHook,
+    LogTimingByEpochHook,
+    TensorboardHook,
+)
 from ._lr_scheduler_hook import LRSchedulerHook
 from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
 
diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/legacy/trainer/hooks/_base_hook.py
similarity index 100%
rename from colossalai/trainer/hooks/_base_hook.py
rename to colossalai/legacy/trainer/hooks/_base_hook.py
diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
similarity index 98%
rename from colossalai/trainer/hooks/_checkpoint_hook.py
rename to colossalai/legacy/trainer/hooks/_checkpoint_hook.py
index 3bcb32cd2..7754ebcc3 100644
--- a/colossalai/trainer/hooks/_checkpoint_hook.py
+++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import torch
-from colossalai.logging import get_dist_logger
 
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
 from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
 from colossalai.utils.checkpointing import save_checkpoint
+
 from ._lr_scheduler_hook import LRSchedulerHook
 
 
diff --git a/colossalai/trainer/hooks/_commons_.py b/colossalai/legacy/trainer/hooks/_commons_.py
similarity index 100%
rename from colossalai/trainer/hooks/_commons_.py
rename to colossalai/legacy/trainer/hooks/_commons_.py
diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py
similarity index 98%
rename from colossalai/trainer/hooks/_log_hook.py
rename to colossalai/legacy/trainer/hooks/_log_hook.py
index 5b1f33983..1efc8be76 100644
--- a/colossalai/trainer/hooks/_log_hook.py
+++ b/colossalai/legacy/trainer/hooks/_log_hook.py
@@ -3,17 +3,17 @@
 
 import os
 import os.path as osp
-
 from typing import List
+
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
 from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
-    is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.registry import HOOKS
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
+
 from ._base_hook import BaseHook
 from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric
 
 
 class LogByEpochHook(BaseHook):
diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
similarity index 99%
rename from colossalai/trainer/hooks/_lr_scheduler_hook.py
rename to colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
index c6da33442..0d19ab08a 100644
--- a/colossalai/trainer/hooks/_lr_scheduler_hook.py
+++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
@@ -1,6 +1,7 @@
-from colossalai.registry import HOOKS
 from torch import Tensor
 
+from colossalai.registry import HOOKS
+
 from ._metric_hook import LearningRateMetric, MetricHook
 
 
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py
similarity index 98%
rename from colossalai/trainer/hooks/_metric_hook.py
rename to colossalai/legacy/trainer/hooks/_metric_hook.py
index 526d6c746..96def4172 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -6,6 +6,7 @@ from typing import Callable
 
 import torch
 import torch.distributed as dist
+
 from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -19,8 +20,8 @@ from ._commons_ import _format_number
 class Metric(ABC):
     """A basic class of metric collectors. It collects a specific
     metric during training or evaluation and would always be used with
-    :class:`MetricHook` to help it update its states and show the 
-    metric. So please use corresponding hook class to make the metric 
+    :class:`MetricHook` to help it update its states and show the
+    metric. So please use corresponding hook class to make the metric
     collector work.
 
     Args:
@@ -220,9 +221,9 @@ class AccuracyMetric(Metric):
 
 
 class MetricHook(BaseHook):
-    """Specialized hook classes for :class:`Metric`. 
-    Some help metric collectors initialize, reset and 
-    update their states. Others are used to display and 
+    """Specialized hook classes for :class:`Metric`.
+    Some help metric collectors initialize, reset and
+    update their states. Others are used to display and
     record the metric.
 
     Args:
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 715c15eb6..24aa2610f 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -43,7 +43,7 @@ from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.utils.timer import MultiTimer
 from model_zoo.gpt import GPTLMLoss
 from torch.nn import functional as F
@@ -268,3 +268,4 @@ def train():
         return_output_label=False,
     )
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 6adfe4f11..3475d8f07 100644
--- a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -38,7 +38,7 @@ from colossalai.builder import build_pipeline_model
 from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, get_dataloader
 from timm.models import vision_transformer as vit
 from torchvision import transforms
@@ -245,3 +245,4 @@ def train():
                 hooks=hook_list,
                 display_progress=True)
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index a2deaeb88..5b0b694b3 100644
--- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -79,7 +79,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 ```
 
 - Other modules
@@ -644,3 +644,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs>  train_hybrid.py --config ./co
 # If your torch >= 1.9.0
 # python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/en/basics/engine_trainer.md b/docs/source/en/basics/engine_trainer.md
index d2f99563f..6d2355ad9 100644
--- a/docs/source/en/basics/engine_trainer.md
+++ b/docs/source/en/basics/engine_trainer.md
@@ -64,7 +64,7 @@ Trainer is a more high-level wrapper for the user to execute training with fewer
 
 ```python
 from colossalai.logging import get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 
 # build components and initialize with colossalai.initialize
 ...
@@ -107,7 +107,7 @@ If you want to customize your own hook class, you can inherit `hooks.BaseHook` a
 
 ```python
 from colossalai.logging import get_dist_logger
-from colossalai.trainer import hooks
+from colossalai.legacy.trainer import hooks
 
 class LogMessageHook(hooks.BaseHook):
 
@@ -345,7 +345,7 @@ If you wish to train with a trainer object, you can follow the code snippet belo
 
 ```python
 from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 
 
 # create a trainer object
@@ -387,3 +387,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
 # with trainer
 python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/en/basics/model_checkpoint.md b/docs/source/en/basics/model_checkpoint.md
index 70334f1c4..c3ba5b04b 100644
--- a/docs/source/en/basics/model_checkpoint.md
+++ b/docs/source/en/basics/model_checkpoint.md
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):
 
 #### Save when using trainer
 ```python
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 model = ...
 engine, _, _, _ = colossalai.initialize(model=model, ...)
 trainer = Trainer(engine, ...)
@@ -61,3 +61,4 @@ model = ...
 load_checkpoint('xxx.pt', model)
 ... # train or test
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/en/features/mixed_precision_training.md b/docs/source/en/features/mixed_precision_training.md
index 8579d586e..164b2a215 100644
--- a/docs/source/en/features/mixed_precision_training.md
+++ b/docs/source/en/features/mixed_precision_training.md
@@ -267,7 +267,7 @@ from pathlib import Path
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_dataloader
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from timm.models import vit_base_patch16_224
 from torchvision import datasets, transforms
diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md
index 30654b0b0..8b5f228a9 100644
--- a/docs/source/en/features/pipeline_parallel.md
+++ b/docs/source/en/features/pipeline_parallel.md
@@ -79,7 +79,7 @@ import colossalai.nn as col_nn
 
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, get_dataloader
 from colossalai.context import ParallelMode
 from colossalai.pipeline.pipelinable import PipelinableContext
@@ -157,3 +157,4 @@ trainer.fit(train_dataloader=train_dataloader,
 ```
 
 We use `2` pipeline stages and the batch will be split into `4` micro batches.
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 6c6dcf6e8..a199d31e7 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -43,7 +43,7 @@ from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.utils.timer import MultiTimer
 from model_zoo.gpt import GPTLMLoss
 from torch.nn import functional as F
@@ -273,3 +273,4 @@ def train():
         return_output_label=False,
     )
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 495c7fa36..d3a98c89b 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -36,7 +36,7 @@ from colossalai.builder import build_pipeline_model
 from colossalai.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, get_dataloader
 from timm.models import vision_transformer as vit
 from torchvision import transforms
@@ -244,3 +244,4 @@ def train():
                 hooks=hook_list,
                 display_progress=True)
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index 5ad083920..ddc2502f0 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -74,7 +74,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 ```
 
 - 其他模块
@@ -589,3 +589,4 @@ torchrun --standalone --nproc_per_node <NUM_GPUs>  train_hybrid.py --config ./co
 # If your torch >= 1.9.0
 # python -m torch.distributed.run --standalone --nproc_per_node= <NUM_GPUs> train_hybrid.py --config ./configs/config_hybrid_parallel.py
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/basics/engine_trainer.md b/docs/source/zh-Hans/basics/engine_trainer.md
index a35bd87c4..e57220292 100644
--- a/docs/source/zh-Hans/basics/engine_trainer.md
+++ b/docs/source/zh-Hans/basics/engine_trainer.md
@@ -61,7 +61,7 @@ Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除
 
 ```python
 from colossalai.logging import get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 
 # build components and initialize with colossalai.initialize
 ...
@@ -104,7 +104,7 @@ trainer.fit(
 
 ```python
 from colossalai.logging import get_dist_logger
-from colossalai.trainer import hooks
+from colossalai.legacy.trainer import hooks
 
 class LogMessageHook(hooks.BaseHook):
 
@@ -341,7 +341,7 @@ for epoch in range(gpc.config.NUM_EPOCHS):
 
 ```python
 from colossalai.nn.metric import Accuracy
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 
 
 # create a trainer object
@@ -384,3 +384,4 @@ python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr loc
 # with trainer
 python -m torch.distributed.launch --nproc_per_node <num_gpus> --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/basics/model_checkpoint.md b/docs/source/zh-Hans/basics/model_checkpoint.md
index a5374b750..4a49d373a 100644
--- a/docs/source/zh-Hans/basics/model_checkpoint.md
+++ b/docs/source/zh-Hans/basics/model_checkpoint.md
@@ -41,7 +41,7 @@ for epoch in range(num_epochs):
 
 #### 用 trainer 保存
 ```python
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 model = ...
 engine, _, _, _ = colossalai.initialize(model=model, ...)
 trainer = Trainer(engine, ...)
@@ -61,3 +61,4 @@ model = ...
 load_checkpoint('xxx.pt', model)
 ... # train or test
 ```
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/features/mixed_precision_training.md b/docs/source/zh-Hans/features/mixed_precision_training.md
index a92e7e093..35a73f1ad 100644
--- a/docs/source/zh-Hans/features/mixed_precision_training.md
+++ b/docs/source/zh-Hans/features/mixed_precision_training.md
@@ -245,7 +245,7 @@ from pathlib import Path
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_dataloader
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from timm.models import vit_base_patch16_224
 from torchvision import datasets, transforms
diff --git a/docs/source/zh-Hans/features/pipeline_parallel.md b/docs/source/zh-Hans/features/pipeline_parallel.md
index 98096b1d7..1497dc399 100644
--- a/docs/source/zh-Hans/features/pipeline_parallel.md
+++ b/docs/source/zh-Hans/features/pipeline_parallel.md
@@ -78,7 +78,7 @@ import colossalai.nn as col_nn
 
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, get_dataloader
 from colossalai.context import ParallelMode
 from colossalai.pipeline.pipelinable import PipelinableContext
@@ -156,3 +156,4 @@ trainer.fit(train_dataloader=train_dataloader,
 ```
 
 我们使用 `2` 个流水段,并且 batch 将被切分为 `4` 个 micro batches。
+<!-- doc-test-command: echo  -->
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
index 6be0b9e8d..b239b626c 100644
--- a/examples/language/gpt/titans/train_gpt.py
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -10,9 +10,9 @@ import colossalai
 import colossalai.utils as utils
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.trainer import Trainer, hooks
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn import LinearWarmupLR
-from colossalai.trainer import Trainer, hooks
 from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
 from colossalai.utils.timer import MultiTimer
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
diff --git a/pytest.ini b/pytest.ini
index d25865d52..b869bb4fa 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,4 +4,4 @@ markers =
     gpu: tests which requires a single GPU
     dist: tests which are run in a multi-GPU or multi-machine environment
     experiment: tests for experimental features
-addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy
diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
deleted file mode 100644
index 4992acbd7..000000000
--- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os
-from pathlib import Path
-
-import pytest
-import torch
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
-
-import colossalai
-from colossalai.amp import AMP_TYPE
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_dist_logger
-from colossalai.nn import CrossEntropyLoss
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
-from colossalai.trainer import Trainer, hooks
-from colossalai.utils import get_dataloader
-
-BATCH_SIZE = 4
-NUM_EPOCHS = 60
-WARMUP_EPOCHS = 5
-CONFIG = dict(NUM_MICRO_BATCHES=2,
-              parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
-
-
-def run_trainer(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    logger = get_dist_logger()
-
-    # get logger
-    logger = get_dist_logger()
-
-    pipelinable = PipelinableContext()
-    try:
-        from titans.model.vit import vit_tiny_patch4_32
-    except ImportError:
-        logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
-        logger.warning('please install titan from https://github.com/hpcaitech/Titans')
-        return
-    with pipelinable:
-        model = vit_tiny_patch4_32()
-    pipelinable.to_layer_list()
-    pipelinable.policy = "uniform"
-    model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
-
-    # create dataloaders
-    root = Path(os.environ['DATA'])
-    transform_train = transforms.Compose([
-        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
-        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
-        transforms.ToTensor(),
-        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-    ])
-    train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
-    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
-
-    # create loss function
-    criterion = CrossEntropyLoss(label_smoothing=0.1)
-
-    # create optimizer
-    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
-
-    # create lr scheduler
-    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
-
-    # initialize
-    engine, train_dataloader, *_ = colossalai.initialize(model=model,
-                                                         optimizer=optimizer,
-                                                         criterion=criterion,
-                                                         train_dataloader=train_dataloader)
-
-    logger = get_dist_logger()
-
-    trainer = Trainer(engine=engine, logger=logger)
-
-    hook_list = [
-        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
-    ]
-
-    trainer.fit(train_dataloader=train_dataloader,
-                epochs=NUM_EPOCHS,
-                max_steps=2,
-                hooks=hook_list,
-                display_progress=True)
-
-
-@pytest.mark.dist
-@skip_if_not_enough_gpus(min_gpus=8)
-@rerun_if_address_is_in_use()
-def test_hybrid_parallel():
-    spawn(run_trainer, 8)
-
-
-if __name__ == '__main__':
-    test_hybrid_parallel()
diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
similarity index 100%
rename from tests/test_trainer/test_pipeline/test_p2p.py
rename to tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
similarity index 100%
rename from tests/test_trainer/test_pipeline/test_pipeline_schedule.py
rename to tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py
diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
similarity index 97%
rename from tests/test_trainer/test_trainer_with_non_pipe_schedule.py
rename to tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
index 753f82222..dab0e53a4 100644
--- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
+++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -3,9 +3,9 @@ import torch
 
 import colossalai
 from colossalai.amp.amp_type import AMP_TYPE
+from colossalai.legacy.trainer import Trainer
 from colossalai.logging import get_dist_logger
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer
 from tests.components_to_test.registry import non_distributed_component_funcs
 
diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
similarity index 98%
rename from tests/test_trainer/test_trainer_with_pipe_schedule.py
rename to tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
index bb63d51a0..7dfbec854 100644
--- a/tests/test_trainer/test_trainer_with_pipe_schedule.py
+++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py
@@ -12,9 +12,9 @@ from torchvision.models import resnet18
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.trainer import Trainer
 from colossalai.logging import get_dist_logger
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer, get_dataloader
 
 BATCH_SIZE = 4
diff --git a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py
index 6a0509555..4bacb2181 100644
--- a/tests/test_pipeline/test_cuda_rpc_performance.py
+++ b/tests/test_pipeline/test_cuda_rpc_performance.py
@@ -1,25 +1,16 @@
 import os
-from typing import Callable, List, Optional, Type, Union
 import time
 
 import pytest
 import torch
 import torch.nn as nn
+from rpc_test_utils import parse_args, rpc_run
 from titans.dataloader.cifar10 import build_cifar
 from torchvision.models import resnet50
-from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
 from tqdm import tqdm
 
-from rpc_test_utils import rpc_run, parse_args
-import colossalai
-import colossalai.nn as col_nn
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, get_dataloader
-from colossalai.context import ParallelMode
-from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
-from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine
-from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.pipelinable import PipelinableContext
+from colossalai.pipeline.rpc import OneFOneBPipelineEngine
 
 
 def flatten(x):

From 8accecd55bf1a5aaaeb4b84c06fac0d63850fd5e Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Mon, 4 Sep 2023 11:33:40 +0800
Subject: [PATCH 3/4] [legacy] move engine to legacy (#4560)

* [legacy] move engine to legacy

* [example] fix seq parallel example

* [test] test gemini plugin hang

* [example] update seq parallel requirements
---
 colossalai/builder/builder.py                 |  2 +-
 colossalai/initialize.py                      |  6 +-
 colossalai/{ => legacy}/engine/__init__.py    |  0
 .../{ => legacy}/engine/_base_engine.py       | 12 ++-
 .../engine/gradient_accumulation/__init__.py  |  4 +-
 .../_gradient_accumulation.py                 |  4 +-
 .../engine/gradient_handler/__init__.py       |  0
 .../_base_gradient_handler.py                 |  0
 .../_data_parallel_gradient_handler.py        |  2 +-
 .../gradient_handler/_moe_gradient_handler.py |  2 +-
 .../_pipeline_parallel_gradient_handler.py    |  0
 .../_sequence_parallel_gradient_handler.py    |  2 +-
 .../_zero_gradient_handler.py                 |  0
 .../engine/gradient_handler/utils.py          |  0
 .../{ => legacy}/engine/schedule/__init__.py  |  0
 .../engine/schedule/_base_schedule.py         |  2 +-
 .../engine/schedule/_non_pipeline_schedule.py |  2 +-
 .../engine/schedule/_pipeline_schedule.py     | 10 +--
 .../engine/schedule/_pipeline_schedule_v2.py  |  2 +-
 colossalai/legacy/trainer/_trainer.py         |  2 +-
 colossalai/utils/profiler/profiler.py         | 18 ++---
 .../profiler/stateful_tensor_mem_extention.py |  8 +-
 .../advanced_tutorials/add_your_parallel.md   |  7 +-
 .../train_gpt_using_hybrid_parallelism.md     |  2 +-
 .../train_vit_using_pipeline_parallelism.md   |  2 +-
 .../train_vit_with_hybrid_parallelism.md      |  2 +-
 docs/source/en/features/gradient_handler.md   |  3 +-
 .../advanced_tutorials/add_your_parallel.md   |  7 +-
 .../train_gpt_using_hybrid_parallelism.md     |  2 +-
 .../train_vit_using_pipeline_parallelism.md   |  2 +-
 .../train_vit_with_hybrid_parallelism.md      |  2 +-
 .../zh-Hans/features/gradient_handler.md      |  3 +-
 .../data/datasets/indexed_dataset.py          | 77 +++++++------------
 .../sequence_parallel/requirements.txt        |  1 +
 examples/tutorial/sequence_parallel/train.py  |  2 +-
 .../test_plugin/test_gemini_plugin.py         |  2 +-
 tests/test_moe/test_grad_handler.py           |  2 +-
 tests/test_moe/test_moe_zero_model.py         |  2 +-
 tests/test_moe/test_moe_zero_optim.py         |  2 +-
 39 files changed, 93 insertions(+), 105 deletions(-)
 rename colossalai/{ => legacy}/engine/__init__.py (100%)
 rename colossalai/{ => legacy}/engine/_base_engine.py (97%)
 rename colossalai/{ => legacy}/engine/gradient_accumulation/__init__.py (94%)
 rename colossalai/{ => legacy}/engine/gradient_accumulation/_gradient_accumulation.py (98%)
 rename colossalai/{ => legacy}/engine/gradient_handler/__init__.py (100%)
 rename colossalai/{ => legacy}/engine/gradient_handler/_base_gradient_handler.py (100%)
 rename colossalai/{ => legacy}/engine/gradient_handler/_data_parallel_gradient_handler.py (94%)
 rename colossalai/{ => legacy}/engine/gradient_handler/_moe_gradient_handler.py (97%)
 rename colossalai/{ => legacy}/engine/gradient_handler/_pipeline_parallel_gradient_handler.py (100%)
 rename colossalai/{ => legacy}/engine/gradient_handler/_sequence_parallel_gradient_handler.py (94%)
 rename colossalai/{ => legacy}/engine/gradient_handler/_zero_gradient_handler.py (100%)
 rename colossalai/{ => legacy}/engine/gradient_handler/utils.py (100%)
 rename colossalai/{ => legacy}/engine/schedule/__init__.py (100%)
 rename colossalai/{ => legacy}/engine/schedule/_base_schedule.py (98%)
 rename colossalai/{ => legacy}/engine/schedule/_non_pipeline_schedule.py (97%)
 rename colossalai/{ => legacy}/engine/schedule/_pipeline_schedule.py (98%)
 rename colossalai/{ => legacy}/engine/schedule/_pipeline_schedule_v2.py (98%)

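Editor's note: as with the trainer move, user code only needs new import paths; a sketch mirroring the changes below:

    # before this patch
    # from colossalai.engine import Engine
    # from colossalai.engine.schedule import PipelineSchedule

    # after this patch
    from colossalai.legacy.engine import Engine
    from colossalai.legacy.engine.schedule import PipelineSchedule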
diff --git a/colossalai/builder/builder.py b/colossalai/builder/builder.py
index 4a9076013..a14509392 100644
--- a/colossalai/builder/builder.py
+++ b/colossalai/builder/builder.py
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
         optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler
 
     Returns:
-        An object of :class:`colossalai.engine.BaseGradientHandler`
+        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
     """
     config_ = config.copy()
     config_['model'] = model
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index dc0df0517..32354dde8 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -21,9 +21,9 @@ from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
     InterleavedPipelineSchedule,
     NonPipelineSchedule,
     PipelineSchedule,
diff --git a/colossalai/engine/__init__.py b/colossalai/legacy/engine/__init__.py
similarity index 100%
rename from colossalai/engine/__init__.py
rename to colossalai/legacy/engine/__init__.py
diff --git a/colossalai/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py
similarity index 97%
rename from colossalai/engine/_base_engine.py
rename to colossalai/legacy/engine/_base_engine.py
index db27ad0e8..9af4469f4 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/legacy/engine/_base_engine.py
@@ -8,11 +8,17 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 
-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+    BaseSchedule,
+    InterleavedPipelineSchedule,
+    NonPipelineSchedule,
+    PipelineSchedule,
+)
 from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
+
 
 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/legacy/engine/gradient_accumulation/__init__.py
similarity index 94%
rename from colossalai/engine/gradient_accumulation/__init__.py
rename to colossalai/legacy/engine/gradient_accumulation/__init__.py
index 4cb6f4ad7..670c26d06 100644
--- a/colossalai/engine/gradient_accumulation/__init__.py
+++ b/colossalai/legacy/engine/gradient_accumulation/__init__.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
 
 from ._gradient_accumulation import (
     GradAccumDataloader,
@@ -33,7 +33,7 @@ def accumulate_gradient(model: nn.Module,
         dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
             your dataloader object, would be called like iter(dataloader)
         accumulate_size (int): the number of steps to accumulate gradients
-        gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
+        gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]):
             list of gradient handler objects. Default is None.
         lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
             your ``lr_scheduler`` object for gradient accumulation. Defaults to None.
diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
similarity index 98%
rename from colossalai/engine/gradient_accumulation/_gradient_accumulation.py
rename to colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
index cf66be1cd..c466f7e2d 100644
--- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
@@ -10,7 +10,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import conditional_context
 
@@ -262,7 +262,7 @@ class GradAccumGradientHandler:
     before accumulation size is reached.
 
     Args:
-        grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
+        grad_handler (:class:`colossalai.legacy.engine.BaseGradientHandler`):
             Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`.
         accumulate_size (int): The number of steps to accumulate gradients.
 
diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py
similarity index 100%
rename from colossalai/engine/gradient_handler/__init__.py
rename to colossalai/legacy/engine/gradient_handler/__init__.py
diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py
similarity index 100%
rename from colossalai/engine/gradient_handler/_base_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py
diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
similarity index 94%
rename from colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
index 5cc7169c5..d0196e3c4 100644
--- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 
-from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
similarity index 97%
rename from colossalai/engine/gradient_handler/_moe_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
index b499345d4..f2db95752 100644
--- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,9 +1,9 @@
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 
-from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
similarity index 100%
rename from colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
similarity index 94%
rename from colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
index ea4f0fbb1..f13568094 100644
--- a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 
-from ...context.parallel_mode import ParallelMode
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
 
diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
similarity index 100%
rename from colossalai/engine/gradient_handler/_zero_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/legacy/engine/gradient_handler/utils.py
similarity index 100%
rename from colossalai/engine/gradient_handler/utils.py
rename to colossalai/legacy/engine/gradient_handler/utils.py
diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/legacy/engine/schedule/__init__.py
similarity index 100%
rename from colossalai/engine/schedule/__init__.py
rename to colossalai/legacy/engine/schedule/__init__.py
diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py
similarity index 98%
rename from colossalai/engine/schedule/_base_schedule.py
rename to colossalai/legacy/engine/schedule/_base_schedule.py
index a2d500411..7505a3eb2 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/legacy/engine/schedule/_base_schedule.py
@@ -95,7 +95,7 @@ class BaseSchedule(ABC):
         """The process function over a batch of dataset for training or evaluation.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
             forward_only (bool): If True, the process won't include backward.
             return_loss (bool, optional): If False, the loss won't be returned.
diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
similarity index 97%
rename from colossalai/engine/schedule/_non_pipeline_schedule.py
rename to colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
index b9239d928..b67893c1a 100644
--- a/colossalai/engine/schedule/_non_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
@@ -54,7 +54,7 @@ class NonPipelineSchedule(BaseSchedule):
         The returned labels and loss will None if :attr:`return_loss` is False.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 If True, the model is run for the forward pass, else back propagation will be executed.
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
similarity index 98%
rename from colossalai/engine/schedule/_pipeline_schedule.py
rename to colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 9fc301a26..88b54ce6a 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -236,7 +236,7 @@ class PipelineSchedule(BaseSchedule):
         Returns output tensor. This is a helper function and can be ignored by users.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
@@ -274,7 +274,7 @@ class PipelineSchedule(BaseSchedule):
         This is a helper function and can be ignored by users.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
             output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
             output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.
@@ -314,7 +314,7 @@ class PipelineSchedule(BaseSchedule):
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 Whether run forward step only. Default is false. If true, no backward will be run.
@@ -518,7 +518,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         Returns output tensor. This is a helper function and can be ignored by users.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             model_chunk_id (int): The id of model chunks.
             input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
@@ -555,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         communication between pipeline stages as needed.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 Whether run forward step only. Default is false. If true, no backward will be run.
diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
similarity index 98%
rename from colossalai/engine/schedule/_pipeline_schedule_v2.py
rename to colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 89e45c7aa..9e7372b67 100644
--- a/colossalai/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -69,7 +69,7 @@ class PipelineScheduleV2(PipelineSchedule):
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
 
         Args:
-            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
             data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
             forward_only (bool, optional):
                 Whether run forward step only. Default is false. If true, no backward will be run.
diff --git a/colossalai/legacy/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py
index fb66acec5..1847e5622 100644
--- a/colossalai/legacy/trainer/_trainer.py
+++ b/colossalai/legacy/trainer/_trainer.py
@@ -4,7 +4,7 @@ import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from colossalai.engine import Engine
+from colossalai.legacy.engine import Engine
 from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import DistributedLogger
 from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py
index 8f43a0b96..3026d723d 100644
--- a/colossalai/utils/profiler/profiler.py
+++ b/colossalai/utils/profiler/profiler.py
@@ -1,17 +1,17 @@
-import os
-from typing import List
-from colossalai.engine import Engine
-from torch.profiler import profile as torch_profile
-from torch.profiler.profiler import ProfilerAction
-from typing import Any, Callable, Iterable, Optional
-from torch.autograd import ProfilerActivity
+import gzip
 import json
 import os
 import tempfile
-import gzip
+from typing import Any, Callable, Iterable, List, Optional
+
+from torch.autograd import ProfilerActivity
+from torch.profiler import profile as torch_profile
+from torch.profiler.profiler import ProfilerAction
+
+from colossalai.legacy.engine import Engine
+from colossalai.logging import get_dist_logger
 from colossalai.utils.profiler.extention import ProfilerExtension
 from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
-from colossalai.logging import get_dist_logger
 
 
 class profile(torch_profile):
diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
index 127055c8c..412bd7277 100644
--- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py
+++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
@@ -1,12 +1,14 @@
 import os
 import threading
 import time
-import torch
 from enum import Enum
 from typing import List
-from colossalai.gemini.stateful_tensor import StatefulTensor
+
+import torch
+
 from colossalai.gemini.ophooks import BaseOpHook
-from colossalai.engine import Engine
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.legacy.engine import Engine
 from colossalai.utils.profiler.extention import ProfilerExtension
 
 
diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md
index 1caf58c87..cda49af47 100644
--- a/docs/source/en/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/en/advanced_tutorials/add_your_parallel.md
@@ -92,14 +92,14 @@ follow the steps below to create a new distributed initialization.
 
 Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce
 strategies may be executed for different kinds of parallelism, users can
-inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
+inherit `colossalai.legacy.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
 uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data
 parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own
 gradient handler like below:
 
 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
 
 @GRADIENT_HANDLER.register_module
 class YourGradientHandler(BaseGradientHandler):
@@ -121,4 +121,5 @@ gradient_handlers = [
 
 Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
 schedules. If you want to modify how the forward and backward passes are executed, you can
-inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
+inherit `colossalai.legacy.engine.schedule.BaseSchedule` and implement the `forward_backward_step` function.
+<!-- doc-test-command: echo  -->
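The tutorial hunk above only shows the first lines of the handler class, so here is a fuller, hedged sketch of what a custom handler could look like after this move to `colossalai.legacy.engine`. It assumes, as the bundled data-parallel handler suggests, that `BaseGradientHandler` keeps the wrapped model in `self._model` and that subclasses override `handle_gradient`; the class name is hypothetical and the snippet is illustrative rather than library-verified. Note that the registry import still points at `colossalai.registry` at this point in the series; a later patch relocates it to `colossalai.legacy.registry`.

```python
import torch.distributed as dist

from colossalai.registry import GRADIENT_HANDLER
from colossalai.legacy.engine import BaseGradientHandler


@GRADIENT_HANDLER.register_module
class AllReduceGradientHandler(BaseGradientHandler):
    """Illustrative handler: all-reduce every gradient over the default group."""

    def handle_gradient(self):
        # Assumes BaseGradientHandler stores the wrapped model as ``self._model``,
        # as the built-in data-parallel handler appears to do.
        for param in self._model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
```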
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 24aa2610f..98c16e922 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -39,7 +39,7 @@ from colossalai.amp import AMP_TYPE
 from colossalai.builder.pipeline import partition_uniform
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 3475d8f07..370931d87 100644
--- a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -35,7 +35,7 @@ import colossalai.nn as col_nn
 import torch
 import torch.nn as nn
 from colossalai.builder import build_pipeline_model
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.legacy.trainer import Trainer, hooks
diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index 5b0b694b3..fc1101c5a 100644
--- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -415,7 +415,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw
 
 #### Import modules
 ```python
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.utils import MultiTimer
 import os
diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md
index 757016fcb..14ced32b8 100644
--- a/docs/source/en/features/gradient_handler.md
+++ b/docs/source/en/features/gradient_handler.md
@@ -29,7 +29,7 @@ To implement a customized gradient handler, you need to follow these steps.
 
 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
 
 
 @GRADIENT_HANDLER.register_module
@@ -61,3 +61,4 @@ to demonstrate the use of gradient handler. In this example, we used `DataParall
 ```shell
 python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500  train_with_engine.py
 ```
+<!-- doc-test-command: echo  -->
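To round off the steps described in this tutorial, a hedged sketch of how the registered handler is switched on from the training config. The `gradient_handler = [dict(type=...)]` pattern is the one shown later in this same tutorial; the handler name is hypothetical and must match the class name registered with `@GRADIENT_HANDLER.register_module`.

```python
# Hypothetical entry in the training config module: ``type`` is the class name
# of the handler registered above; Colossal-AI builds the handler from this
# list (via build_gradient_handler in colossalai/initialize.py).
gradient_handler = [dict(type='MyGradientHandler')]
```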
diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
index 059eb014a..abfe058c6 100644
--- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
@@ -81,14 +81,14 @@ Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管
 ## 梯度 Handler
 
 梯度 handler 是对参数的梯度执行 all-reduce 操作的对象。由于不同的 all-reduce 策略或许在不同的并行中被执行,用户可以继承
-`colossalai.engine.gradient_handler.BaseGradientHandler` 来实现其策略。目前,Colossal-AI 使用普通的数据并行梯度 handler 在数据并行的 rank 间 all-reduce 梯度。
+`colossalai.legacy.engine.gradient_handler.BaseGradientHandler` 来实现其策略。目前,Colossal-AI 使用普通的数据并行梯度 handler 在数据并行的 rank 间 all-reduce 梯度。
 如果数据并行被检测到,梯度 handler 会被自动添加进 engine。
 
 你可以添加你自己的梯度 handler,如下所示:
 
 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
 
 @GRADIENT_HANDLER.register_module
 class YourGradientHandler(BaseGradientHandler):
@@ -109,4 +109,5 @@ gradient_handlers = [
 ## Schedule
 
 Schedule 包含了如何执行前向和后向计算。目前, Colossal-AI 提供了流水和非流水的 schedule。
-如果你想修改前向和后向计算的执行方式,你可以继承 `colossalai.engine.schedule.BaseSchedule` 并实现 `forward_back_step` 函数。
+如果你想修改前向和后向计算的执行方式，你可以继承 `colossalai.legacy.engine.schedule.BaseSchedule` 并实现 `forward_backward_step` 函数。
+<!-- doc-test-command: echo  -->
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index a199d31e7..84b48165b 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -39,7 +39,7 @@ from colossalai.amp import AMP_TYPE
 from colossalai.builder.pipeline import partition_uniform
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index d3a98c89b..1ac01c207 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -33,7 +33,7 @@ import colossalai.nn as col_nn
 import torch
 import torch.nn as nn
 from colossalai.builder import build_pipeline_model
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.legacy.trainer import Trainer, hooks
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index ddc2502f0..650bab105 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -380,7 +380,7 @@ def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kw
 
 #### 导入模块
 ```python
-from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.utils import MultiTimer
 import os
diff --git a/docs/source/zh-Hans/features/gradient_handler.md b/docs/source/zh-Hans/features/gradient_handler.md
index 701c60fed..b08dd6806 100644
--- a/docs/source/zh-Hans/features/gradient_handler.md
+++ b/docs/source/zh-Hans/features/gradient_handler.md
@@ -26,7 +26,7 @@
 
 ```python
 from colossalai.registry import GRADIENT_HANDLER
-from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
 
 
 @GRADIENT_HANDLER.register_module
@@ -57,3 +57,4 @@ gradient_handler = [dict(type='MyGradientHandler')]
 ```shell
 python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500  train_with_engine.py
 ```
+<!-- doc-test-command: echo  -->
diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py
index b4febcd82..9a25dc453 100644
--- a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py
+++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py
@@ -3,17 +3,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 # copied from fairseq/fairseq/data/indexed_dataset.py
 # Removed IndexedRawTextDataset since it relied on Fairseq dictionary
 # other slight modifications to remove fairseq dependencies
 # Added document index to index file and made it accessible.
 #    An empty sentence no longer separates documents.
 
-from functools import lru_cache
 import os
 import shutil
 import struct
+from functools import lru_cache
 from itertools import accumulate
 
 import numpy as np
@@ -88,16 +87,7 @@ def write_longs(f, a):
     f.write(np.array(a, dtype=np.int64))
 
 
-dtypes = {
-    1: np.uint8,
-    2: np.int8,
-    3: np.int16,
-    4: np.int32,
-    5: np.int64,
-    6: np.float,
-    7: np.double,
-    8: np.uint16
-}
+dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16}
 
 
 def code(dtype):
@@ -136,10 +126,8 @@ class IndexedDataset(torch.utils.data.Dataset):
     def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
-            assert magic == self._HDR_MAGIC, (
-                'Index file doesn\'t match expected format. '
-                'Make sure that --dataset-impl is configured properly.'
-            )
+            assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
+                                              'Make sure that --dataset-impl is configured properly.')
             version = f.read(8)
             assert struct.unpack('<Q', version) == (1,)
             code, self.element_size = struct.unpack('<QQ', f.read(16))
@@ -198,13 +186,11 @@ class IndexedDataset(torch.utils.data.Dataset):
 
     @staticmethod
     def exists(path):
-        return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-        )
+        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
 
     @property
     def supports_prefetch(self):
-        return False  # avoid prefetching to save memory
+        return False    # avoid prefetching to save memory
 
 
 class IndexedCachedDataset(IndexedDataset):
@@ -233,7 +219,7 @@ class IndexedCachedDataset(IndexedDataset):
         for i in indices:
             self.cache_index[i] = ptx
             size = self.data_offsets[i + 1] - self.data_offsets[i]
-            a = self.cache[ptx: ptx + size]
+            a = self.cache[ptx:ptx + size]
             self.data_file.seek(self.data_offsets[i] * self.element_size)
             self.data_file.readinto(a)
             ptx += size
@@ -250,7 +236,7 @@ class IndexedCachedDataset(IndexedDataset):
             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
             a = np.empty(tensor_size, dtype=self.dtype)
             ptx = self.cache_index[i]
-            np.copyto(a, self.cache[ptx: ptx + a.size])
+            np.copyto(a, self.cache[ptx:ptx + a.size])
             return a
         elif isinstance(idx, slice):
             # Hack just to make this work, can optimizer later if necessary
@@ -261,15 +247,7 @@ class IndexedCachedDataset(IndexedDataset):
 
 
 class IndexedDatasetBuilder(object):
-    element_sizes = {
-        np.uint8: 1,
-        np.int8: 1,
-        np.int16: 2,
-        np.int32: 4,
-        np.int64: 8,
-        np.float: 4,
-        np.double: 8
-    }
+    element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}
 
     def __init__(self, out_file, dtype=np.int32):
         self.out_file = open(out_file, 'wb')
@@ -332,12 +310,15 @@ def _warmup_mmap_file(path):
 
 
 class MMapIndexedDataset(torch.utils.data.Dataset):
+
     class Index(object):
         _HDR_MAGIC = b'MMIDIDX\x00\x00'
 
         @classmethod
         def writer(cls, path, dtype):
+
             class _Writer(object):
+
                 def __enter__(self):
                     self._file = open(path, 'wb')
 
@@ -384,10 +365,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         def __init__(self, path, skip_warmup=False):
             with open(path, 'rb') as stream:
                 magic_test = stream.read(9)
-                assert self._HDR_MAGIC == magic_test, (
-                    'Index file doesn\'t match expected format. '
-                    'Make sure that --dataset-impl is configured properly.'
-                )
+                assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
+                                                       'Make sure that --dataset-impl is configured properly.')
                 version = struct.unpack('<Q', stream.read(8))
                 assert (1,) == version
 
@@ -406,16 +385,16 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
             print("    reading sizes...")
-            self._sizes = np.frombuffer(
-                self._bin_buffer,
-                dtype=np.int32,
-                count=self._len,
-                offset=offset)
+            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
             print("    reading pointers...")
-            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
+            self._pointers = np.frombuffer(self._bin_buffer,
+                                           dtype=np.int64,
+                                           count=self._len,
                                            offset=offset + self._sizes.nbytes)
             print("    reading document index...")
-            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
+            self._doc_idx = np.frombuffer(self._bin_buffer,
+                                          dtype=np.int64,
+                                          count=self._doc_count,
                                           offset=offset + self._sizes.nbytes + self._pointers.nbytes)
 
         def __del__(self):
@@ -480,8 +459,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         if isinstance(idx, int):
             ptr, size = self._index[idx]
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                     count=size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
             return np_array
         elif isinstance(idx, slice):
             start, stop, step = idx.indices(len(self))
@@ -491,8 +469,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             sizes = self._index._sizes[idx]
             offsets = list(accumulate(sizes))
             total_size = sum(sizes)
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                     count=total_size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
             sents = np.split(np_array, offsets[:-1])
             return sents
 
@@ -506,8 +483,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         if length is None:
             length = size - offset
         ptr += offset * np.dtype(self._index.dtype).itemsize
-        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
-                                 count=length, offset=ptr)
+        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
         return np_array
 
     @property
@@ -530,12 +506,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
 
     @staticmethod
     def exists(path):
-        return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
-        )
+        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
 
 
 class MMapIndexedDatasetBuilder(object):
+
     def __init__(self, out_file, dtype=np.int64):
         self._data_file = open(out_file, 'wb')
         self._dtype = dtype
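Besides the reformatting, the hunks above replace `np.float` with the builtin `float` in the `dtypes` and `element_sizes` tables, since recent NumPy releases no longer provide `np.float`. A small hedged round-trip check of that table; the local `code` helper only mirrors what the module's own `code(dtype)` lookup is assumed to do.

```python
import numpy as np

# Same table as in indexed_dataset.py after this patch.
dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: float, 7: np.double, 8: np.uint16}


def code(dtype):
    # Assumed behaviour of the module's ``code`` helper: return the integer
    # key whose dtype matches the argument.
    for k, v in dtypes.items():
        if v == dtype:
            return k
    raise ValueError(f'unsupported dtype: {dtype}')


assert dtypes[code(np.int32)] is np.int32
assert dtypes[code(float)] is float
```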
diff --git a/examples/tutorial/sequence_parallel/requirements.txt b/examples/tutorial/sequence_parallel/requirements.txt
index b49a94554..4fc576453 100644
--- a/examples/tutorial/sequence_parallel/requirements.txt
+++ b/examples/tutorial/sequence_parallel/requirements.txt
@@ -1,2 +1,3 @@
 colossalai
 torch
+six
diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py
index a89747b58..86c4edeb5 100644
--- a/examples/tutorial/sequence_parallel/train.py
+++ b/examples/tutorial/sequence_parallel/train.py
@@ -11,8 +11,8 @@ import colossalai
 from colossalai.amp import AMP_TYPE
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import PipelineSchedule
 from colossalai.kernel import LayerNorm
+from colossalai.legacy.engine.schedule import PipelineSchedule
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import FusedAdam
 from colossalai.utils import MultiTimer, is_using_pp
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 4fc67bd29..23561f8ae 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -98,7 +98,7 @@ def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool
         ]:
             continue
         err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
-
+        torch.cuda.empty_cache()
         if err is None:
             passed_models.append(name)
         else:
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index e7002a75f..9c84a99cd 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 
 import colossalai
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index ec37967f1..595d4374d 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -3,7 +3,7 @@ import torch
 
 import colossalai
 from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn import MoeLoss
 from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index efc6e9dda..a43ae764d 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -4,7 +4,7 @@ import torch
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn import MoeLoss
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
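Because this patch relocates the gradient handlers (and the rest of the engine package) under `colossalai.legacy`, downstream code that still imports the old paths will stop resolving. A hedged compatibility shim, using the MoE handler touched in the tests above as an example; treat it as a sketch for user code, not something introduced by the patch itself.

```python
# Hypothetical shim for code that must run both before and after this rename.
try:
    from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
except ImportError:  # older Colossal-AI versions without the ``legacy`` package
    from colossalai.engine.gradient_handler import MoeGradientHandler
```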

From ac178ca5c17ca751ff9df38be81da2b4a005fc0d Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Mon, 4 Sep 2023 19:56:42 +0800
Subject: [PATCH 4/4] [legacy] move builder and registry to legacy (#4603)

---
 .../tensor_shard/node_handler/registry.py     |   2 +-
 colossalai/context/parallel_context.py        |   2 +-
 .../initializer_1d.py                         |   3 +-
 .../initializer_2d.py                         |   2 +-
 .../initializer_2p5d.py                       |   3 +-
 .../initializer_3d.py                         |   2 +-
 .../initializer_data.py                       |   2 +-
 .../initializer_model.py                      |   6 +-
 .../initializer_pipeline.py                   |   2 +-
 .../initializer_sequence.py                   |   2 +-
 .../initializer_tensor.py                     |   5 +-
 colossalai/initialize.py                      |   2 +-
 colossalai/{ => legacy}/builder/__init__.py   |   0
 colossalai/{ => legacy}/builder/builder.py    |   2 +-
 .../_data_parallel_gradient_handler.py        |   2 +-
 .../gradient_handler/_moe_gradient_handler.py |   2 +-
 .../_pipeline_parallel_gradient_handler.py    |   2 +-
 .../_sequence_parallel_gradient_handler.py    |   2 +-
 .../_zero_gradient_handler.py                 |   2 +-
 colossalai/{ => legacy}/registry/__init__.py  |   0
 colossalai/{ => legacy}/registry/registry.py  |   4 +-
 .../legacy/trainer/hooks/_checkpoint_hook.py  |   2 +-
 colossalai/legacy/trainer/hooks/_log_hook.py  |   2 +-
 .../trainer/hooks/_lr_scheduler_hook.py       |   2 +-
 .../legacy/trainer/hooks/_metric_hook.py      |   6 +-
 colossalai/nn/layer/parallel_1d/layers.py     |   2 +-
 colossalai/nn/layer/parallel_2d/layers.py     |  19 +-
 colossalai/nn/layer/parallel_2p5d/layers.py   |  26 ++-
 colossalai/nn/layer/parallel_3d/layers.py     |   2 +-
 .../nn/layer/parallel_sequence/layers.py      |  10 +-
 colossalai/nn/layer/vanilla/layers.py         |   2 +-
 colossalai/nn/loss/loss_1d.py                 | 211 +++++++++---------
 colossalai/nn/loss/loss_2d.py                 |  13 +-
 colossalai/nn/loss/loss_2p5d.py               |  13 +-
 colossalai/nn/loss/loss_3d.py                 |  13 +-
 colossalai/nn/loss/loss_moe.py                | 161 ++++++-------
 colossalai/nn/lr_scheduler/cosine.py          |   3 +-
 colossalai/nn/lr_scheduler/linear.py          |   2 +-
 colossalai/nn/lr_scheduler/multistep.py       |   3 +-
 colossalai/nn/lr_scheduler/onecycle.py        |   2 +-
 colossalai/nn/lr_scheduler/poly.py            |   3 +-
 colossalai/nn/lr_scheduler/torch.py           |   4 +-
 colossalai/nn/optimizer/cpu_adam.py           |   2 +-
 colossalai/nn/optimizer/fused_adam.py         |   2 +-
 colossalai/nn/optimizer/fused_lamb.py         |   2 +-
 colossalai/nn/optimizer/fused_sgd.py          |   2 +-
 colossalai/nn/optimizer/hybrid_adam.py        |   2 +-
 colossalai/nn/optimizer/lamb.py               |   2 +-
 colossalai/nn/optimizer/lars.py               |  35 ++-
 .../data_sampler/data_parallel_sampler.py     |  26 +--
 .../gemini/ophooks/_shard_grad_ophook.py      |   2 +-
 .../gemini/ophooks/_shard_param_ophook.py     |   2 +-
 .../zero/legacy/sharded_model/zero_hook.py    |   2 +-
 .../advanced_tutorials/add_your_parallel.md   |   2 +-
 .../train_gpt_using_hybrid_parallelism.md     |   2 +-
 .../train_vit_using_pipeline_parallelism.md   |  12 +-
 .../train_vit_with_hybrid_parallelism.md      |   8 +-
 docs/source/en/features/gradient_handler.md   |   2 +-
 .../advanced_tutorials/add_your_parallel.md   |   2 +-
 .../train_gpt_using_hybrid_parallelism.md     |   2 +-
 .../train_vit_using_pipeline_parallelism.md   |  12 +-
 .../train_vit_with_hybrid_parallelism.md      |   8 +-
 .../zh-Hans/features/gradient_handler.md      |   2 +-
 .../language/gpt/titans/dataset/webtext.py    |   2 +-
 examples/language/gpt/titans/model/embed.py   |   2 +-
 65 files changed, 353 insertions(+), 332 deletions(-)
 rename colossalai/{ => legacy}/builder/__init__.py (100%)
 rename colossalai/{ => legacy}/builder/builder.py (98%)
 rename colossalai/{ => legacy}/registry/__init__.py (100%)
 rename colossalai/{ => legacy}/registry/registry.py (98%)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 8e06cec4f..1a90c72bd 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -1,5 +1,5 @@
 class Registry:
-    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
+    # TODO: refactor the registry classes used in colossalai.legacy.registry, colossalai.fx and here
 
     def __init__(self, name):
         self.name = name
diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 003f0cdd9..7186f052e 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -15,8 +15,8 @@ from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 from colossalai.logging import get_dist_logger
-from colossalai.registry import DIST_GROUP_INITIALIZER
 
 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py
index 4c0502804..ba601d0bf 100644
--- a/colossalai/context/process_group_initializer/initializer_1d.py
+++ b/colossalai/context/process_group_initializer/initializer_1d.py
@@ -2,8 +2,9 @@
 # -*- encoding: utf-8 -*-
 
 import torch.distributed as dist
+
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py
index 7fbe3be59..999cd5f0c 100644
--- a/colossalai/context/process_group_initializer/initializer_2d.py
+++ b/colossalai/context/process_group_initializer/initializer_2d.py
@@ -3,7 +3,7 @@ import math
 import torch.distributed as dist
 
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py
index 6b6fdc5d7..b92ae2eec 100644
--- a/colossalai/context/process_group_initializer/initializer_2p5d.py
+++ b/colossalai/context/process_group_initializer/initializer_2p5d.py
@@ -4,9 +4,10 @@
 import math
 
 import torch.distributed as dist
+
 from colossalai.context import Config
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py
index 1ed8eec86..6bca05ad7 100644
--- a/colossalai/context/process_group_initializer/initializer_3d.py
+++ b/colossalai/context/process_group_initializer/initializer_3d.py
@@ -6,7 +6,7 @@ import math
 import torch.distributed as dist
 
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/context/process_group_initializer/initializer_data.py
index 9715ebff7..b9dec4541 100644
--- a/colossalai/context/process_group_initializer/initializer_data.py
+++ b/colossalai/context/process_group_initializer/initializer_data.py
@@ -3,7 +3,7 @@
 
 from torch import distributed as dist
 
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py
index 99b9cc0d4..614ba372f 100644
--- a/colossalai/context/process_group_initializer/initializer_model.py
+++ b/colossalai/context/process_group_initializer/initializer_model.py
@@ -2,9 +2,11 @@
 # -*- encoding: utf-8 -*-
 
 import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
 
 
 @DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py
index 0ddb52f63..e093333ad 100644
--- a/colossalai/context/process_group_initializer/initializer_pipeline.py
+++ b/colossalai/context/process_group_initializer/initializer_pipeline.py
@@ -3,7 +3,7 @@
 
 from torch import distributed as dist
 
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py
index 251a29407..a6e26b6bc 100644
--- a/colossalai/context/process_group_initializer/initializer_sequence.py
+++ b/colossalai/context/process_group_initializer/initializer_sequence.py
@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-
 import torch.distributed as dist
 
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
 from .initializer_tensor import Initializer_Tensor
diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py
index d2b5be9cf..3be89e52a 100644
--- a/colossalai/context/process_group_initializer/initializer_tensor.py
+++ b/colossalai/context/process_group_initializer/initializer_tensor.py
@@ -3,9 +3,10 @@
 
 import torch.distributed as dist
 
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
 
 
 @DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 32354dde8..a1694e059 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -17,10 +17,10 @@ from torch.utils.data import DataLoader
 
 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
+from colossalai.legacy.builder.builder import build_gradient_handler
 from colossalai.legacy.engine import Engine
 from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
 from colossalai.legacy.engine.schedule import (
diff --git a/colossalai/builder/__init__.py b/colossalai/legacy/builder/__init__.py
similarity index 100%
rename from colossalai/builder/__init__.py
rename to colossalai/legacy/builder/__init__.py
diff --git a/colossalai/builder/builder.py b/colossalai/legacy/builder/builder.py
similarity index 98%
rename from colossalai/builder/builder.py
rename to colossalai/legacy/builder/builder.py
index a14509392..ff14f46dc 100644
--- a/colossalai/builder/builder.py
+++ b/colossalai/legacy/builder/builder.py
@@ -3,7 +3,7 @@
 
 import inspect
 
-from colossalai.registry import *
+from colossalai.legacy.registry import *
 
 
 def build_from_config(module, config: dict):
diff --git a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
index d0196e3c4..c5da2e55a 100644
--- a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
@@ -1,6 +1,6 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
index f2db95752..395d83da0 100644
--- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,7 +1,7 @@
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 
 from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
index 5b49a9c03..7d4d9d73a 100644
--- a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
@@ -7,7 +7,7 @@ import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ._base_gradient_handler import BaseGradientHandler
 
diff --git a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
index f13568094..41098ab39 100644
--- a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -1,6 +1,6 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ._base_gradient_handler import BaseGradientHandler
 from .utils import bucket_allreduce
diff --git a/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
index 19fd1e97f..4ca7cd0b0 100644
--- a/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
@@ -1,4 +1,4 @@
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 
 from ._base_gradient_handler import BaseGradientHandler
 
diff --git a/colossalai/registry/__init__.py b/colossalai/legacy/registry/__init__.py
similarity index 100%
rename from colossalai/registry/__init__.py
rename to colossalai/legacy/registry/__init__.py
diff --git a/colossalai/registry/registry.py b/colossalai/legacy/registry/registry.py
similarity index 98%
rename from colossalai/registry/registry.py
rename to colossalai/legacy/registry/registry.py
index 8a4173f7a..50d6b74c5 100644
--- a/colossalai/registry/registry.py
+++ b/colossalai/legacy/registry/registry.py
@@ -6,7 +6,7 @@ from typing import List
 
 
 class Registry:
-    """This is a registry class used to register classes and modules so that a universal 
+    """This is a registry class used to register classes and modules so that a universal
     object builder can be enabled.
 
     Args:
@@ -42,7 +42,7 @@ class Registry:
         return module_class
 
     def get_module(self, module_name: str):
-        """Retrieves a module with name `module_name` and returns the module if it has 
+        """Retrieves a module with name `module_name` and returns the module if it has
         already been registered before.
 
         Args:
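The docstrings in this hunk describe the registry as a name-keyed store whose `register_module` returns the class it registers and whose `get_module` looks a class up again by name. A hedged usage sketch against the new location; it assumes the package re-exports `Registry` as the pre-move `colossalai.registry` did, that the constructor takes a registry name, and that entries are keyed by class name.

```python
from colossalai.legacy.registry import Registry

MODELS = Registry('models')    # assumed constructor signature: Registry(name)


@MODELS.register_module
class TinyModel:
    """Registered under its class name, per the docstring above."""


# ``get_module`` is expected to return the class registered above.
assert MODELS.get_module('TinyModel') is TinyModel
```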
diff --git a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
index 7754ebcc3..6b150d291 100644
--- a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
+++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
@@ -2,9 +2,9 @@
 # -*- encoding: utf-8 -*-
 import torch
 
+from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.trainer.hooks import BaseHook
 from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
 from colossalai.utils.checkpointing import save_checkpoint
 
 from ._lr_scheduler_hook import LRSchedulerHook
diff --git a/colossalai/legacy/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py
index 1efc8be76..7d9ad19aa 100644
--- a/colossalai/legacy/trainer/hooks/_log_hook.py
+++ b/colossalai/legacy/trainer/hooks/_log_hook.py
@@ -7,9 +7,9 @@ from typing import List
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
 from colossalai.logging import DistributedLogger
-from colossalai.registry import HOOKS
 from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
 
 from ._base_hook import BaseHook
diff --git a/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
index 0d19ab08a..6d60966da 100644
--- a/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
+++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
@@ -1,6 +1,6 @@
 from torch import Tensor
 
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
 
 from ._metric_hook import LearningRateMetric, MetricHook
 
diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py
index 96def4172..d0598c240 100644
--- a/colossalai/legacy/trainer/hooks/_metric_hook.py
+++ b/colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -10,7 +10,7 @@ import torch.distributed as dist
 from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage
 
 from ._base_hook import BaseHook
@@ -356,7 +356,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
-                 gpc.get_world_size(ParallelMode.DATA)
+                gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
 
         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -367,7 +367,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
-                 gpc.get_world_size(ParallelMode.DATA)
+                gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
 
         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 406173a18..7b129009e 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -15,8 +15,8 @@ from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py
index f3a4d2bbb..1a01d5437 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/nn/layer/parallel_2d/layers.py
@@ -5,21 +5,30 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
-                         reduce_scatter_tensor_2d, split_batch_2d)
+from ._operation import (
+    Matmul_AB_2D,
+    Matmul_ABT_2D,
+    add_bias_2d,
+    all_gather_tensor_2d,
+    classifier_2d,
+    layernorm_2d,
+    reduce_scatter_tensor_2d,
+    split_batch_2d,
+)
 from ._utils import assert_summa_initialization, get_summa_dim_from_env
 
 
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py
index f849cbbe7..62c4292fd 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/nn/layer/parallel_2p5d/layers.py
@@ -5,22 +5,34 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
-                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
+from ._operation import (
+    Matmul_AB_2p5D,
+    Matmul_ABT_2p5D,
+    add_bias_2p5d,
+    all_gather_tensor_2p5d,
+    classifier_2p5d,
+    layernorm_2p5d,
+    reduce_scatter_tensor_2p5d,
+    split_batch_2p5d,
+)
 from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
 
 
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index 99b0c3f8b..7d940aa27 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -13,9 +13,9 @@ from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,
diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py
index 0887f8389..4d0ff2e06 100644
--- a/colossalai/nn/layer/parallel_sequence/layers.py
+++ b/colossalai/nn/layer/parallel_sequence/layers.py
@@ -2,20 +2,20 @@
 # -*- encoding: utf-8 -*-
 
 import math
-import colossalai
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Parameter
 
+import colossalai
+from colossalai.context import seed
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
-from colossalai.registry import LAYERS
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.context import seed
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.registry import LAYERS
+from colossalai.nn.layer.parallel_sequence._operation import RingAV, RingQK
 
 
 @LAYERS.register_module
diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/nn/layer/vanilla/layers.py
index 225aed391..0e11fc4d0 100644
--- a/colossalai/nn/layer/vanilla/layers.py
+++ b/colossalai/nn/layer/vanilla/layers.py
@@ -8,8 +8,8 @@ from torch import nn as nn
 from torch.nn.parameter import Parameter
 
 from colossalai.context import seed
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.cuda import get_current_device
 
 from ..utils import to_2tuple
diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/nn/loss/loss_1d.py
index dd548c1d3..8c9483fcc 100644
--- a/colossalai/nn/loss/loss_1d.py
+++ b/colossalai/nn/loss/loss_1d.py
@@ -1,105 +1,106 @@
-import torch
-import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.modules.loss import _Loss
-
-
-class _VocabParallelCrossEntropy1D(torch.autograd.Function):
-
-    @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, vocab_parallel_logits, targets, process_group):
-        if process_group is None:
-            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
-
-        # Maximum value along vocab dimension across all GPUs.
-        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
-        # Subtract the maximum value.
-        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
-
-        # Get the partition's vocab indices
-        partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = dist.get_rank(process_group)
-        vocab_start_index = partition_vocab_size * rank
-        vocab_end_index = vocab_start_index + partition_vocab_size
-
-        # Create a mask of valid vocab ids (1 means it needs to be masked).
-        target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
-        masked_target = targets.clone() - vocab_start_index
-        masked_target[target_mask] = 0
-
-        # Get predicted-logits = logits[target].
-        # For Simplicity, we convert logits to a 2-D tensor with size
-        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
-        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
-        masked_target_1d = masked_target.view(-1)
-        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
-        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
-        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
-        predicted_logits = predicted_logits_1d.view_as(targets)
-        predicted_logits[target_mask] = 0.0
-        # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
-
-        # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = torch.exp(vocab_parallel_logits)
-        sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
-
-        # Loss = log(sum(exp(logits))) - predicted-logit.
-        loss = torch.log(sum_exp_logits) - predicted_logits
-        # Store softmax, target-mask and masked-target for backward pass.
-        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
-        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
-        return loss
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output):
-
-        # Retrieve tensors from the forward path.
-        softmax, target_mask, masked_target_1d = ctx.saved_tensors
-
-        # All the inputs have softmax as their gradient.
-        grad_input = softmax
-        # For simplicity, work with the 2D gradient.
-        partition_vocab_size = softmax.size()[-1]
-        grad_2d = grad_input.view(-1, partition_vocab_size)
-
-        # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
-        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
-
-        # Finally elementwise multiplication with the output gradients.
-        grad_input.mul_(grad_output.unsqueeze(dim=-1))
-
-        return grad_input, None, None
-
-
-@LOSSES.register_module
-class VocabParallelCrossEntropyLoss1D(_Loss):
-    """Vocab parallel cross entropy loss for 1D parallelism.
-
-    Args:
-        reduction (bool, optional): whether to average the loss, defaults to True.
-    """
-
-    def __init__(self, reduction=True):
-        super().__init__()
-        self.reduction_mean = reduction
-
-    def forward(self, logits, targets, process_group=None):
-        """Calculate loss between logits and targets.
-
-        Args:
-            logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-            targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-        """
-        loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
-        if self.reduction_mean:
-            loss = loss.mean()
-        return loss
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.modules.loss import _Loss
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+
+
+class _VocabParallelCrossEntropy1D(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(ctx, vocab_parallel_logits, targets, process_group):
+        if process_group is None:
+            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
+
+        # Maximum value along vocab dimension across all GPUs.
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
+        # Subtract the maximum value.
+        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+        # Get the partition's vocab indices
+        partition_vocab_size = vocab_parallel_logits.size()[-1]
+        rank = dist.get_rank(process_group)
+        vocab_start_index = partition_vocab_size * rank
+        vocab_end_index = vocab_start_index + partition_vocab_size
+
+        # Create a mask of valid vocab ids (1 means it needs to be masked).
+        target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
+        masked_target = targets.clone() - vocab_start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For simplicity, we convert logits to a 2-D tensor with size
+        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+        predicted_logits = predicted_logits_1d.view_as(targets)
+        predicted_logits[target_mask] = 0.0
+        # All reduce is needed to get the chunks from other GPUs.
+        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = torch.exp(vocab_parallel_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        loss = torch.log(sum_exp_logits) - predicted_logits
+        # Store softmax, target-mask and masked-target for backward pass.
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+        return loss
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output):
+
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax
+        # For simplicity, work with the 2D gradient.
+        partition_vocab_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, partition_vocab_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+
+        # Finally elementwise multiplication with the output gradients.
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+        return grad_input, None, None
+
+
+@LOSSES.register_module
+class VocabParallelCrossEntropyLoss1D(_Loss):
+    """Vocab parallel cross entropy loss for 1D parallelism.
+
+    Args:
+        reduction (bool, optional): whether to average the loss, defaults to True.
+    """
+
+    def __init__(self, reduction=True):
+        super().__init__()
+        self.reduction_mean = reduction
+
+    def forward(self, logits, targets, process_group=None):
+        """Calculate loss between logits and targets.
+
+        Args:
+            logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+            targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+        """
+        loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
+        if self.reduction_mean:
+            loss = loss.mean()
+        return loss
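For readers tracing the forward pass above: the computation is the standard numerically stable cross entropy, `log(sum(exp(logits))) - predicted_logit`, with the max, the target logit, and the exp-sum each all-reduced across the 1D tensor-parallel group. Below is a minimal single-device sketch of the same math with the all-reduces omitted; the function name and shapes are illustrative and are not part of this patch.

```python
import torch


def reference_cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Subtract the per-row maximum for numerical stability (mirrors the logits_max all-reduce above).
    logits = logits - logits.max(dim=-1, keepdim=True).values
    # Gather the logit of the target class for each sample (the "predicted_logits" above).
    predicted = logits.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # loss = log(sum(exp(logits))) - predicted_logit
    return torch.logsumexp(logits, dim=-1) - predicted


logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
print(reference_cross_entropy(logits, targets).mean())
```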
diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py
index 7da8b2d69..6db40c0f3 100644
--- a/colossalai/nn/loss/loss_2d.py
+++ b/colossalai/nn/loss/loss_2d.py
@@ -1,15 +1,16 @@
 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
-from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
 
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
+from colossalai.utils import get_current_device
+
 
 @LOSSES.register_module
 class CrossEntropyLoss2D(_Loss):
diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py
index 63dc4f33a..9c78a1ef0 100644
--- a/colossalai/nn/loss/loss_2p5d.py
+++ b/colossalai/nn/loss/loss_2p5d.py
@@ -1,15 +1,16 @@
 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
-from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
 
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
+from colossalai.utils import get_current_device
+
 
 @LOSSES.register_module
 class CrossEntropyLoss2p5D(_Loss):
diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/nn/loss/loss_3d.py
index f27d57ad6..5c0f26640 100644
--- a/colossalai/nn/loss/loss_3d.py
+++ b/colossalai/nn/loss/loss_3d.py
@@ -1,15 +1,16 @@
 import torch
 import torch.distributed as dist
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
 
+from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
+from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.utils import get_current_device
+
 
 @LOSSES.register_module
 class CrossEntropyLoss3D(_Loss):
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py
index a8b18a3e3..40cea788c 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/nn/loss/loss_moe.py
@@ -1,80 +1,81 @@
-import torch.nn as nn
-from colossalai.registry import LOSSES
-from torch.nn.modules.loss import _Loss
-from colossalai.context.moe_context import MOE_CONTEXT
-
-
-@LOSSES.register_module
-class MoeCrossEntropyLoss(_Loss):
-    r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
-
-    Args:
-        input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-        target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-        aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.
-
-    The ``args`` and ``kwargs`` should include parameters below:
-    ::
-
-        weight (Tensor, optional)
-        size_average (bool, optional)
-        ignore_index (int, optional)
-        reduce (bool, optional)
-        reduction (str, optional)
-        label_smoothing (float, optional)
-
-    More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
-    `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
-    """
-
-    def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
-        super().__init__()
-        self.loss = nn.CrossEntropyLoss(*args, **kwargs)
-        self.aux_weight = aux_weight
-
-    def forward(self, *args):
-        """
-        The ``args`` should at least include parameters below:
-        ::
-
-            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
-        More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
-        `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
-        """
-        main_loss = self.loss(*args)
-        aux_loss = MOE_CONTEXT.get_loss()
-        return main_loss + self.aux_weight * aux_loss
-
-
-@LOSSES.register_module
-class MoeLoss(_Loss):
-    """A wrapper class for any loss module to add with auxiliary loss.
-
-    Args:
-        aux_weight (float): Weight of auxiliary loss in total loss.
-        loss_fn (``Callable``): Loss function.
-        args (list): Args in loss function.
-        kwargs (dict): Kwargs in loss function
-    """
-
-    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
-        super().__init__()
-        self.loss_fn = loss_fn(*args, **kwargs)
-        self.aux_weight = aux_weight
-
-    def forward(self, *args, **kwargs):
-        """
-        The ``args`` and ``kwargs`` should at least include parameters below:
-        ::
-
-            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
-        Note:
-            The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
-        """
-        main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = MOE_CONTEXT.get_loss()
-        return main_loss + self.aux_weight * aux_loss
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.registry import LOSSES
+
+
+@LOSSES.register_module
+class MoeCrossEntropyLoss(_Loss):
+    r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
+
+    Args:
+        input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+        target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+        aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
+
+    The ``args`` and ``kwargs`` should include the parameters below:
+    ::
+
+        weight (Tensor, optional)
+        size_average (bool, optional)
+        ignore_index (int, optional)
+        reduce (bool, optional)
+        reduction (str, optional)
+        label_smoothing (float, optional)
+
+    More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+    `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
+    """
+
+    def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
+        super().__init__()
+        self.loss = nn.CrossEntropyLoss(*args, **kwargs)
+        self.aux_weight = aux_weight
+
+    def forward(self, *args):
+        """
+        The ``args`` should at least include the parameters below:
+        ::
+
+            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+        More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+        `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
+        """
+        main_loss = self.loss(*args)
+        aux_loss = MOE_CONTEXT.get_loss()
+        return main_loss + self.aux_weight * aux_loss
+
+
+@LOSSES.register_module
+class MoeLoss(_Loss):
+    """A wrapper class for any loss module to add with auxiliary loss.
+
+    Args:
+        aux_weight (float): Weight of auxiliary loss in total loss.
+        loss_fn (``Callable``): Loss function.
+        args (list): Positional arguments passed to the loss function.
+        kwargs (dict): Keyword arguments passed to the loss function.
+    """
+
+    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
+        super().__init__()
+        self.loss_fn = loss_fn(*args, **kwargs)
+        self.aux_weight = aux_weight
+
+    def forward(self, *args, **kwargs):
+        """
+        The ``args`` and ``kwargs`` should at least include the parameters below:
+        ::
+
+            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+        Note:
+            The parameters included in ``args`` and ``kwargs`` may vary depending on the loss function.
+        """
+        main_loss = self.loss_fn(*args, **kwargs)
+        aux_loss = MOE_CONTEXT.get_loss()
+        return main_loss + self.aux_weight * aux_loss
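Both classes add `aux_weight * MOE_CONTEXT.get_loss()` on top of the main loss. A minimal construction sketch based on the signatures above is shown below; it is illustrative only, and calling `forward()` additionally requires an initialized MoE context so that `MOE_CONTEXT.get_loss()` returns the accumulated auxiliary loss.

```python
import torch.nn as nn

from colossalai.nn.loss.loss_moe import MoeCrossEntropyLoss, MoeLoss

# Extra positional/keyword arguments are forwarded to torch.nn.CrossEntropyLoss.
moe_ce = MoeCrossEntropyLoss(aux_weight=0.01, label_smoothing=0.1)

# MoeLoss wraps an arbitrary loss constructor; the remaining arguments go to loss_fn.
moe_wrapped = MoeLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss, reduction='mean')
```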
diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py
index aab523bef..0010435c2 100644
--- a/colossalai/nn/lr_scheduler/cosine.py
+++ b/colossalai/nn/lr_scheduler/cosine.py
@@ -1,6 +1,7 @@
 from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
 
-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
+
 from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler
 
 
diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py
index 556938b8a..251779647 100644
--- a/colossalai/nn/lr_scheduler/linear.py
+++ b/colossalai/nn/lr_scheduler/linear.py
@@ -1,6 +1,6 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
 
 
 @LR_SCHEDULERS.register_module
diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py
index 29531a9e3..4f18b49fc 100644
--- a/colossalai/nn/lr_scheduler/multistep.py
+++ b/colossalai/nn/lr_scheduler/multistep.py
@@ -2,7 +2,8 @@ from typing import List
 
 from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR
 
-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
+
 from .delayed import WarmupScheduler
 
 
diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py
index 8007fd360..20e9aaec6 100644
--- a/colossalai/nn/lr_scheduler/onecycle.py
+++ b/colossalai/nn/lr_scheduler/onecycle.py
@@ -1,6 +1,6 @@
 from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR
 
-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
 
 
 @LR_SCHEDULERS.register_module
diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py
index 16352bc51..a98506423 100644
--- a/colossalai/nn/lr_scheduler/poly.py
+++ b/colossalai/nn/lr_scheduler/poly.py
@@ -1,6 +1,7 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
+
 from .delayed import WarmupScheduler
 
 
diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py
index 05d2a49c1..09f5d4585 100644
--- a/colossalai/nn/lr_scheduler/torch.py
+++ b/colossalai/nn/lr_scheduler/torch.py
@@ -1,9 +1,9 @@
+from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
 from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
 from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
 from torch.optim.lr_scheduler import StepLR as _StepLR
-from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
 
-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
 
 
 @LR_SCHEDULERS.register_module
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 3a6d37103..210400a21 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 
 from colossalai.kernel.op_builder import CPUAdamBuilder
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 
 from .nvme_optimizer import NVMeOptimizer
 
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 82a6250f1..0d13873cd 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -8,7 +8,7 @@ Licensed under the MIT License.
 '''
 import torch
 
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 from colossalai.utils import multi_tensor_applier
 
 
diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py
index 72520064e..48cc097c7 100644
--- a/colossalai/nn/optimizer/fused_lamb.py
+++ b/colossalai/nn/optimizer/fused_lamb.py
@@ -1,7 +1,7 @@
 # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py
 import torch
 
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 from colossalai.utils import multi_tensor_applier
 
 
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 468713b22..0e8d3fc10 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -2,7 +2,7 @@
 import torch
 from torch.optim.optimizer import Optimizer, required
 
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 from colossalai.utils import multi_tensor_applier
 
 
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 84903ac36..7aa0ced18 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -4,7 +4,7 @@ import torch
 from torch.optim import Adam
 
 from colossalai.kernel.op_builder import FusedOptimBuilder
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 from colossalai.utils import multi_tensor_applier
 
 from .cpu_adam import CPUAdam
diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py
index 399ad39b6..769c11f62 100644
--- a/colossalai/nn/optimizer/lamb.py
+++ b/colossalai/nn/optimizer/lamb.py
@@ -5,7 +5,7 @@ Adapted from the pytorch-lamb library at https://github.com/cybertronai/pytorch-
 import torch
 from torch.optim import Optimizer
 
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 
 
 @OPTIMIZERS.register_module
diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py
index 212f66671..9dbb83b84 100644
--- a/colossalai/nn/optimizer/lars.py
+++ b/colossalai/nn/optimizer/lars.py
@@ -5,7 +5,7 @@ from typing import Iterable
 import torch
 from torch.optim import Optimizer
 
-from colossalai.registry import OPTIMIZERS
+from colossalai.legacy.registry import OPTIMIZERS
 
 
 @OPTIMIZERS.register_module
@@ -22,28 +22,24 @@ class Lars(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
     """
 
-    def __init__(
-            self,
-            params: Iterable[torch.nn.Parameter],
-            lr=1e-3,
-            momentum=0,
-            eeta=1e-3,
-            weight_decay=0,
-            epsilon=0.0
-    ) -> None:
+    def __init__(self,
+                 params: Iterable[torch.nn.Parameter],
+                 lr=1e-3,
+                 momentum=0,
+                 eeta=1e-3,
+                 weight_decay=0,
+                 epsilon=0.0) -> None:
         if not isinstance(lr, float) or lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
-            raise ValueError(
-                "Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         if eeta <= 0 or eeta > 1:
             raise ValueError("Invalid eeta value: {}".format(eeta))
         if epsilon < 0:
             raise ValueError("Invalid epsilon value: {}".format(epsilon))
-        defaults = dict(lr=lr, momentum=momentum,
-                        weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
+        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
 
         super().__init__(params, defaults)
 
@@ -76,11 +72,9 @@ class Lars(Optimizer):
                 if lars:
                     w_norm = torch.norm(p)
                     g_norm = torch.norm(p.grad)
-                    trust_ratio = torch.where(
-                        w_norm > 0 and g_norm > 0,
-                        eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
-                        torch.ones_like(w_norm)
-                    )
+                    trust_ratio = torch.where(w_norm > 0 and g_norm > 0,
+                                              eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
+                                              torch.ones_like(w_norm))
                     trust_ratio.clamp_(0.0, 50)
                     scaled_lr *= trust_ratio.item()
                     if weight_decay != 0:
@@ -90,8 +84,7 @@ class Lars(Optimizer):
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = torch.clone(
-                            decayed_grad).detach()
+                        buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach()
                     else:
                         buf = param_state['momentum_buffer']
                         buf.mul_(momentum).add_(decayed_grad)
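For context, a rough usage sketch of this optimizer based on the `__init__` signature above; the model and hyper-parameter values are illustrative and not part of this patch.

```python
import torch
import torch.nn as nn

from colossalai.nn.optimizer.lars import Lars

model = nn.Linear(16, 4)
optimizer = Lars(model.parameters(), lr=1e-2, momentum=0.9, eeta=1e-3, weight_decay=1e-4)

# One toy step: the trust ratio eeta * ||w|| / (||g|| + weight_decay * ||w|| + eps)
# scales the learning rate per parameter, as in the step() code above.
loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```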
diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py
index 2318e07a7..4ca7bce7b 100644
--- a/colossalai/utils/data_sampler/data_parallel_sampler.py
+++ b/colossalai/utils/data_sampler/data_parallel_sampler.py
@@ -4,15 +4,15 @@
 
 import math
 import random
-import numpy as np
-from typing import TypeVar, Iterator
+from typing import Iterator, TypeVar
 
+import numpy as np
 import torch
-from torch.utils.data import Sampler, Dataset, DataLoader
+from torch.utils.data import DataLoader, Dataset, Sampler
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import DATA_SAMPLERS
+from colossalai.legacy.registry import DATA_SAMPLERS
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -30,11 +30,7 @@ class DataParallelSampler(Sampler):
             the batch size, then the last batch will be smaller, defaults to False.
     """
 
-    def __init__(self,
-                 dataset: Dataset,
-                 shuffle: bool = False,
-                 seed: int = 0,
-                 drop_last: bool = False) -> None:
+    def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
         self.dataset = dataset
         self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
         self.rank = gpc.get_local_rank(ParallelMode.DATA)
@@ -54,8 +50,7 @@ class DataParallelSampler(Sampler):
                 self.num_replicas  # type: ignore[arg-type]
             )
         else:
-            self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)    # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
         self.seed = seed
@@ -72,7 +67,7 @@ class DataParallelSampler(Sampler):
             # set_epoch manually
             self.epoch += 1
         else:
-            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+            indices = list(range(len(self.dataset)))    # type: ignore[arg-type]
 
         if not self.drop_last:
             # add extra samples to make it evenly divisible
@@ -80,8 +75,7 @@ class DataParallelSampler(Sampler):
             if padding_size <= len(indices):
                 indices += indices[:padding_size]
             else:
-                indices += (indices * math.ceil(padding_size /
-                            len(indices)))[:padding_size]
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
         else:
             # remove tail of data to make it evenly divisible.
             indices = indices[:self.total_size]
@@ -109,8 +103,8 @@ class DataParallelSampler(Sampler):
 
 def get_dataloader(dataset,
                    shuffle=False,
-                   seed=1024, 
-                   add_sampler=True, 
+                   seed=1024,
+                   add_sampler=True,
                    drop_last=False,
                    pin_memory=False,
                    num_workers=0,
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
index 8f8fec649..d68a9dc64 100644
--- a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
+++ b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
@@ -1,6 +1,6 @@
 import torch
 
-from colossalai.registry import OPHOOKS
+from colossalai.legacy.registry import OPHOOKS
 
 from . import BaseOpHook
 
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
index a2a62fb97..6b76a2116 100644
--- a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
+++ b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
@@ -1,6 +1,6 @@
 import torch
 
-from colossalai.registry import OPHOOKS
+from colossalai.legacy.registry import OPHOOKS
 
 from . import BaseOpHook
 
diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/zero/legacy/sharded_model/zero_hook.py
index 50f4bdfc7..1815bee3a 100644
--- a/colossalai/zero/legacy/sharded_model/zero_hook.py
+++ b/colossalai/zero/legacy/sharded_model/zero_hook.py
@@ -3,8 +3,8 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
+from colossalai.legacy.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
-from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
 from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md
index cda49af47..384221596 100644
--- a/docs/source/en/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/en/advanced_tutorials/add_your_parallel.md
@@ -98,7 +98,7 @@ parallel gradient handler is added to the engine automatically if data parallel
 gradient handler like below:
 
 ```python
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.legacy.engine import BaseGradientHandler
 
 @GRADIENT_HANDLER.register_module
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 98c16e922..5aa806c64 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -36,7 +36,7 @@ import torch
 import torch.nn as nn
 from colossalai import nn as col_nn
 from colossalai.amp import AMP_TYPE
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 370931d87..6dbe33800 100644
--- a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -34,7 +34,7 @@ import colossalai
 import colossalai.nn as col_nn
 import torch
 import torch.nn as nn
-from colossalai.builder import build_pipeline_model
+from colossalai.legacy.builder import build_pipeline_model
 from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -51,17 +51,17 @@ from torchvision.datasets import CIFAR10
 
 Generally, we provide 3 ways to build a pipelined model:
 
-1. `colossalai.builder.build_pipeline_model_from_cfg`
-2. `colossalai.builder.build_pipeline_model`
+1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+2. `colossalai.legacy.builder.build_pipeline_model`
 3. Split the model by stages by yourself
 
 When your memory can fit the model, you can use the first two methods to build your model; otherwise, you must split the model yourself. The first two methods first build the whole model on the CPU and then split it; finally, you just move the corresponding part of the model to the GPU.
 
-`colossalai.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size).
+`colossalai.legacy.builder.build_pipeline_model_from_cfg()` receives a model config file and can split the model either uniformly (by layer) or in a balanced way (by parameter size).
 
-If you are familiar with `PyTorch`, you can use  `colossalai.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly.
+If you are familiar with `PyTorch`, you can use `colossalai.legacy.builder.build_pipeline_model()`, which receives a `torch.nn.Sequential` model and splits it by layer uniformly.
 
-In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model.
+In this tutorial, we will convert [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) into `torch.nn.Sequential` and then use `colossalai.legacy.builder.build_pipeline_model()` to build the pipelined model.
 
 When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of the pipeline, the first positional argument of `forward()` is the data tensor loaded from the data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`.
 
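A rough sketch of the second method described above, assuming Colossal-AI has already been launched with pipeline parallelism configured (e.g. via `colossalai.launch_from_torch` with a `parallel=dict(pipeline=...)` config); the toy `nn.Sequential` below stands in for the rewritten ViT and is illustrative only.

```python
import torch.nn as nn

from colossalai.legacy.builder import build_pipeline_model

# Toy stand-in for a model that has been rewritten as torch.nn.Sequential.
full_model = nn.Sequential(
    nn.Linear(256, 256), nn.GELU(),
    nn.Linear(256, 256), nn.GELU(),
    nn.Linear(256, 10),
)

# Splits the Sequential uniformly by layer and returns the sub-module
# owned by the current pipeline stage.
pipeline_model = build_pipeline_model(full_model)
```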
diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index fc1101c5a..22022639c 100644
--- a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -273,8 +273,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1  # add 1 for cls token
 
 ### Build pipeline model (`/hybrid_parallel/model/vit.py`)
 Colossal-AI provides two methods to build a pipeline model from the existing model.
-- `colossalai.builder.build_pipeline_model_from_cfg`
-- `colossalai.builder.build_pipeline_model`
+- `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+- `colossalai.legacy.builder.build_pipeline_model`
 
 Besides, you can also build a pipeline model from scratch with Colossal-AI.
 ```python
@@ -284,11 +284,11 @@ from typing import Callable
 import inspect
 import torch
 from colossalai import nn as col_nn
-from colossalai.registry import LAYERS, MODELS
+from colossalai.legacy.registry import LAYERS, MODELS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
 from torch import dtype, nn
 from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
 
diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md
index 14ced32b8..66e5e3a9d 100644
--- a/docs/source/en/features/gradient_handler.md
+++ b/docs/source/en/features/gradient_handler.md
@@ -28,7 +28,7 @@ To implement a customized gradient handler, you need to follow these steps.
 3. implement `handle_gradient` method.
 
 ```python
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
 
 
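For reference, a minimal sketch of such a handler is given below, assuming (as in the built-in handlers) that `BaseGradientHandler` keeps the wrapped model as `self._model`; the class name and reduction strategy are illustrative.

```python
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
from colossalai.legacy.registry import GRADIENT_HANDLER


@GRADIENT_HANDLER.register_module
class NaiveDataParallelGradientHandler(BaseGradientHandler):
    """Averages gradients over the data-parallel group after the backward pass."""

    def handle_gradient(self):
        group = gpc.get_group(ParallelMode.DATA)
        world_size = gpc.get_world_size(ParallelMode.DATA)
        for param in self._model.parameters():
            if param.grad is not None:
                # All-reduce then average so every rank holds identical gradients.
                dist.all_reduce(param.grad, group=group)
                param.grad.div_(world_size)
```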
diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
index abfe058c6..c4b0f6557 100644
--- a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
+++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
@@ -87,7 +87,7 @@ Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管
 你可以添加你自己的梯度 handler,如下所示:
 
 ```python
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.legacy.engine import BaseGradientHandler
 
 @GRADIENT_HANDLER.register_module
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 84b48165b..9cfbf5873 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -36,7 +36,7 @@ import torch
 import torch.nn as nn
 from colossalai import nn as col_nn
 from colossalai.amp import AMP_TYPE
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
index 1ac01c207..5ef863dcd 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -32,7 +32,7 @@ import colossalai
 import colossalai.nn as col_nn
 import torch
 import torch.nn as nn
-from colossalai.builder import build_pipeline_model
+from colossalai.legacy.builder import build_pipeline_model
 from colossalai.legacy.engine.schedule import (InterleavedPipelineSchedule,
                                         PipelineSchedule)
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -48,17 +48,17 @@ from torchvision.datasets import CIFAR10
 
 总的来说, 我们提供3种方法来建立一个流水并行的模型:
 
-1. `colossalai.builder.build_pipeline_model_from_cfg`
-2. `colossalai.builder.build_pipeline_model`
+1. `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+2. `colossalai.legacy.builder.build_pipeline_model`
 3. 自己按阶段拆分模型
 
 当你的内存能够容纳模型时,你可以使用前两种方法来建立你的模型,否则你必须自己分割模型。前两种方法首先在 CPU 上建立整个模型,然后分割模型,最后你可以直接把模型的相应部分移到 GPU 上。
 
-`colossalai.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。
+`colossalai.legacy.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。
 
-如果你熟悉 `PyTorch`, 你可以使用 `colossalai.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。
+如果你熟悉 `PyTorch`, 你可以使用 `colossalai.legacy.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。
 
+在本教程中,我们将把 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) 修改为 `torch.nn.Sequential`,然后使用 `colossalai.legacy.builder.build_pipeline_model()` 来建立流水线模型。
+在本教程中,我们将修改 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential`,然后使用 `colossalai.legacy.builder.build_pipeline_model()` 来建立流水线模型。
 
 当数据是 **一个** `Tensor`, 你可以使用你的模型 `forward()` 中的位置参数来获得数据张量。对于流水线的第一阶段,`forward()` 的第一个位置参数是从数据加载器加载的数据张量。对于其他阶段,`forward()` 的第一个位置参数是上一阶段的输出张量。注意,如果该阶段不是最后一个阶段,则 `forward()` 的返回必须是一个 `Tensor`。
 
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
index 650bab105..803882a5a 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -256,8 +256,8 @@ SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1  # add 1 for cls token
 
 ### 构建流水线模型 (`/hybrid_parallel/model/vit.py`)
 Colossal-AI 提供了两种从现有模型构建流水线模型的方法。
-- `colossalai.builder.build_pipeline_model_from_cfg`
-- `colossalai.builder.build_pipeline_model`
+- `colossalai.legacy.builder.build_pipeline_model_from_cfg`
+- `colossalai.legacy.builder.build_pipeline_model`
 
 此外,您还可以使用 Colossal-AI 从头开始构建流水线模型。
 ```python
@@ -266,11 +266,11 @@ from typing import Callable
 import inspect
 import torch
 from colossalai import nn as col_nn
-from colossalai.registry import LAYERS, MODELS
+from colossalai.legacy.registry import LAYERS, MODELS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from colossalai.builder.pipeline import partition_uniform
+from colossalai.legacy.builder.pipeline import partition_uniform
 from torch import dtype, nn
 from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
 @MODELS.register_module
diff --git a/docs/source/zh-Hans/features/gradient_handler.md b/docs/source/zh-Hans/features/gradient_handler.md
index b08dd6806..3b1140409 100644
--- a/docs/source/zh-Hans/features/gradient_handler.md
+++ b/docs/source/zh-Hans/features/gradient_handler.md
@@ -25,7 +25,7 @@
 3. 实现 `handle_gradient`
 
 ```python
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
 
 
diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py
index 64f5944a9..fdfc57e9b 100644
--- a/examples/language/gpt/titans/dataset/webtext.py
+++ b/examples/language/gpt/titans/dataset/webtext.py
@@ -6,7 +6,7 @@ import torch
 from torch.utils.data import Dataset
 from transformers import GPT2Tokenizer
 
-from colossalai.registry import DATASETS
+from colossalai.legacy.registry import DATASETS
 
 
 @DATASETS.register_module
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
index d825ae92a..668992901 100644
--- a/examples/language/gpt/titans/model/embed.py
+++ b/examples/language/gpt/titans/model/embed.py
@@ -8,11 +8,11 @@ from torch.nn.parameter import Parameter
 
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LAYERS, LOSSES, MODELS
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
 from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
 from colossalai.nn.layer.utils import divide
-from colossalai.registry import LAYERS, LOSSES, MODELS
 from colossalai.utils import get_current_device