From d202cc28c0e7707762bb5f94d944575b327ba903 Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Tue, 9 Jan 2024 10:20:05 +0800
Subject: [PATCH] [npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
---
 .../workflows/example_check_on_dispatch.yml   |   2 +-
 .github/workflows/example_check_on_pr.yml     |   2 +-
 .../workflows/example_check_on_schedule.yml   |   2 +-
 applications/Chat/coati/trainer/ppo.py        |   4 +-
 .../coati/trainer/strategies/colossalai.py    |  13 +-
 applications/Colossal-LLaMA-2/train.py        |  56 ++--
 colossalai/accelerator/__init__.py            |   2 +
 colossalai/accelerator/api.py                 |  13 +-
 colossalai/accelerator/base_accelerator.py    | 236 ++++++++++++++-
 colossalai/accelerator/cpu_accelerator.py     | 277 ++++++++++++++++++
 colossalai/accelerator/cuda_accelerator.py    | 224 +++++++++++++-
 colossalai/accelerator/npu_accelerator.py     | 226 +++++++++++++-
 .../naive_amp/grad_scaler/base_grad_scaler.py |   4 +-
 .../grad_scaler/dynamic_grad_scaler.py        |  14 +-
 .../naive_amp/mixed_precision_mixin/fp16.py   |   4 +-
 .../auto_parallel/offload/amp_optimizer.py    |   6 +-
 colossalai/auto_parallel/offload/solver.py    |   7 +-
 .../booster/mixed_precision/fp16_torch.py     |   4 +-
 colossalai/booster/plugin/gemini_plugin.py    |   5 +-
 .../booster/plugin/hybrid_parallel_plugin.py  |  28 +-
 .../booster/plugin/low_level_zero_plugin.py   |   4 +-
 colossalai/initialize.py                      |  15 +-
 .../extensions/flash_attention/utils.py       |   6 +-
 colossalai/kernel/jit/option.py               |  22 +-
 colossalai/legacy/amp/torch_amp/torch_amp.py  |   5 +-
 colossalai/legacy/communication/p2p.py        |  10 +-
 colossalai/legacy/communication/ring.py       |   6 +-
 colossalai/legacy/communication/utils.py      |   6 +-
 .../legacy/engine/schedule/_base_schedule.py  |   6 +-
 .../engine/schedule/_pipeline_schedule.py     |   6 +-
 .../engine/schedule/_pipeline_schedule_v2.py  |   4 +-
 colossalai/legacy/initialize.py               |   6 +-
 .../nn/layer/colossalai_layer/embedding.py    |   4 +-
 .../layer/colossalai_layer/normalization.py   |   4 +-
 .../legacy/nn/layer/parallel_1d/layers.py     |  20 +-
 .../legacy/nn/layer/parallel_2d/_operation.py |   8 +-
 .../legacy/nn/layer/parallel_2d/layers.py     |  39 ++-
 .../nn/layer/parallel_2p5d/_operation.py      |  14 +-
 .../legacy/nn/layer/parallel_2p5d/layers.py   |  39 ++-
 .../legacy/nn/layer/parallel_3d/layers.py     |  53 +++-
 .../nn/layer/parallel_sequence/_operation.py  |  12 +-
 colossalai/legacy/nn/layer/vanilla/layers.py  |  28 +-
 colossalai/legacy/nn/loss/loss_2d.py          |   4 +-
 colossalai/legacy/nn/loss/loss_2p5d.py        |   4 +-
 colossalai/legacy/nn/loss/loss_3d.py          |   6 +-
 .../legacy/trainer/hooks/_metric_hook.py      |  22 +-
 .../legacy/utils/activation_checkpoint.py     |  10 +-
 colossalai/legacy/utils/memory.py             |   9 +-
 .../utils/profiler/legacy/comm_profiler.py    |   4 +-
 .../legacy/zero/gemini/stateful_tensor_mgr.py |   4 +-
 .../zero/gemini/tensor_placement_policy.py    |   6 +-
 .../bucket_tensor_shard_strategy.py           |   8 +-
 .../zero/shard_utils/tensor_shard_strategy.py |  10 +-
 .../zero/sharded_model/sharded_model_v2.py    |  14 +-
 .../legacy/zero/sharded_model/zero_hook.py    |   4 +-
 colossalai/moe/routers.py                     | 116 ++++----
 colossalai/moe/utils.py                       |  27 +-
 colossalai/pipeline/schedule/generate.py      |   4 +-
 .../pipeline/schedule/interleaved_pp.py       |   6 +-
 colossalai/pipeline/schedule/one_f_one_b.py   |   6 +-
 colossalai/shardformer/layer/utils.py         |  18 +-
 colossalai/testing/utils.py                   |  15 +-
 colossalai/utils/__init__.py                  |   7 -
 colossalai/utils/device.py                    | 223 --------------
 colossalai/utils/timer.py                     |   8 +-
 colossalai/zero/gemini/chunk/chunk.py         |  26 +-
 colossalai/zero/gemini/chunk/manager.py       |  12 +-
 colossalai/zero/gemini/gemini_ddp.py          |   7 +-
 colossalai/zero/gemini/gemini_optimizer.py    |  15 +-
 .../memory_tracer/chunk_memstats_collector.py |   4 +-
 .../gemini/memory_tracer/memory_monitor.py    |   6 +-
 colossalai/zero/gemini/placement_policy.py    |   8 +-
 colossalai/zero/gemini/utils.py               |   6 +-
 colossalai/zero/low_level/low_level_optim.py  |  33 +--
 .../train_gpt_using_hybrid_parallelism.md     |   5 +-
 .../train_gpt_using_hybrid_parallelism.md     |   3 +-
 .../roberta/pretraining/run_pretraining.py    |  11 +-
 .../dreambooth/train_dreambooth_colossalai.py |  12 +-
 .../train_dreambooth_colossalai_lora.py       |  12 +-
 examples/images/resnet/train.py               |   6 +-
 examples/images/vit/vit_benchmark.py          |   5 +-
 examples/inference/benchmark_llama.py         |  11 +-
 examples/inference/run_llama_inference.py     |   4 +-
 examples/language/bert/finetune.py            |  10 +-
 .../auto_offload/train_gpt_offload.py         |   4 +-
 .../language/gpt/gemini/train_gpt_demo.py     |   8 +-
 .../gpt/hybridparallelism/finetune.py         |  10 +-
 examples/language/gpt/titans/model/embed.py   |  14 +-
 examples/language/llama2/benchmark.py         |  11 +-
 examples/language/llama2/data_utils.py        |   6 +-
 examples/language/llama2/finetune.py          |   6 +-
 .../language/llama2/performance_evaluator.py  |   9 +-
 examples/language/llama2/pretrain.py          |   6 +-
 .../openmoe/benchmark/benchmark_cai.py        |  10 +-
 examples/language/openmoe/train.py            |   6 +-
 examples/language/palm/train.py               |   8 +-
 .../tutorial/new_api/cifar_resnet/train.py    |   6 +-
 examples/tutorial/new_api/cifar_vit/train.py  |   6 +-
 .../tutorial/new_api/glue_bert/finetune.py    |   4 +-
 examples/tutorial/opt/opt/run_clm.py          |  16 +-
 .../test_offload/test_perf.py                 |   4 +-
 .../test_compatibility_with_gemini.py         |   8 +-
 .../test_plugin/test_low_level_zero_plugin.py |   6 +-
 tests/test_legacy/test_comm/test_comm.py      |   8 +-
 .../test_1d/checks_1d/check_layer_1d.py       |  26 +-
 .../test_2d/checks_2d/check_layer_2d.py       |  36 +--
 .../test_2d/checks_2d/check_operation_2d.py   |  12 +-
 .../test_2p5d/checks_2p5d/check_layer_2p5d.py |  36 +--
 .../checks_2p5d/check_operation_2p5d.py       |  12 +-
 .../test_3d/checks_3d/check_layer_3d.py       |  30 +-
 .../checks_seq/check_layer_seq.py             |   8 +-
 .../test_trainer/test_pipeline/test_p2p.py    |   4 +-
 tests/test_legacy/test_utils/test_memory.py   |   6 +-
 .../test_utils/test_norm_gradient_clipping.py |   4 +-
 tests/test_moe/test_grad_handler.py           |   6 +-
 tests/test_moe/test_kernel.py                 |  10 +-
 tests/test_moe/test_moe_checkpoint.py         |   4 +-
 tests/test_moe/test_moe_ep_tp.py              |  62 ++--
 tests/test_moe/test_moe_group.py              |   4 +-
 tests/test_optimizer/test_adam_kernel.py      |   7 +-
 tests/test_pipeline/test_p2p_communication.py |   4 +-
 tests/test_zero/test_gemini/test_chunkv2.py   |   6 +-
 tests/test_zero/test_gemini/test_fwd_bwd.py   |   4 +-
 .../test_zero/test_gemini/test_grad_accum.py  |   4 +-
 tests/test_zero/test_gemini/test_inference.py |   8 +-
 tests/test_zero/test_gemini/test_optim.py     |   4 +-
 tests/test_zero/test_gemini/test_search.py    |   4 +-
 .../test_zero/test_low_level/test_grad_acc.py |   7 +-
 128 files changed, 1773 insertions(+), 868 deletions(-)
 create mode 100644 colossalai/accelerator/cpu_accelerator.py
 delete mode 100644 colossalai/utils/device.py

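The theme of the patch is visible throughout the diffs below: every call site that used the removed colossalai.utils.get_current_device helper now goes through the accelerator API. A minimal, illustrative sketch of the migration (not part of the patch; assumes a ColossalAI build that ships the new colossalai.accelerator package shown here):

    # Old (removed by this patch):
    #   from colossalai.utils import get_current_device
    #   device = get_current_device()
    # New:
    import torch
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()            # picks cuda, then npu, then falls back to cpu
    device = accelerator.get_current_device()  # e.g. torch.device("cuda:0") or torch.device("cpu")
    x = torch.zeros(4, device=device)
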
diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml
index 9d3bd9a48..011a0ae03 100644
--- a/.github/workflows/example_check_on_dispatch.yml
+++ b/.github/workflows/example_check_on_dispatch.yml
@@ -47,7 +47,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - name: 📚 Checkout
         uses: actions/checkout@v3
diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 5934704f4..608ae863f 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -79,7 +79,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 15
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
       cancel-in-progress: true
diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml
index 5ed128c3e..4fcd1e3a9 100644
--- a/.github/workflows/example_check_on_schedule.yml
+++ b/.github/workflows/example_check_on_schedule.yml
@@ -35,7 +35,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - name: 📚 Checkout
         uses: actions/checkout@v3
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index d69666898..330e4e0e3 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .base import OnPolicyTrainer
 from .callbacks import Callback
@@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
         self.critic_optim = critic_optim
 
         self.offload_inference_models = offload_inference_models
-        self.device = get_current_device()
+        self.device = get_accelerator().get_current_device()
 
     def _before_fit(
         self,
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index 7129edb06..95f016786 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -6,7 +6,6 @@ import torch.nn as nn
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP
 
 from .ddp import DDPStrategy
@@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):
 
         warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
 
+        # ColossalAI changed the get_current_device API in version 0.3.4 and newer
+        try:
+            from colossalai.accelerator import get_accelerator
+
+            chunk_init_device = get_accelerator().get_current_device()
+        except:
+            from colossalai.utils import get_current_device
+
+            chunk_init_device = get_current_device()
+
         # NOTE: dist should be initialized before calling get_current_device()
         plugin_initializer = lambda: GeminiPlugin(
-            chunk_init_device=get_current_device(),
+            chunk_init_device=chunk_init_device,
             placement_policy=placement_policy,
             shard_param_frac=shard_param_frac,
             offload_optim_frac=offload_optim_frac,
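
The try/except import shim above keeps the Chat strategy compatible with both the new accelerator API and older ColossalAI releases. Where the same shim is needed in several files, it can be factored into a small helper; a sketch under the same assumptions as the block above (illustrative, not part of the patch):

    def resolve_current_device():
        """Return the current device across ColossalAI versions."""
        try:
            # ColossalAI >= 0.3.4 exposes the accelerator package.
            from colossalai.accelerator import get_accelerator
            return get_accelerator().get_current_device()
        except ImportError:
            # Older releases only ship the legacy utils helper.
            from colossalai.utils import get_current_device
            return get_current_device()
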
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 41b4ef031..92863e8e4 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -1,44 +1,37 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Continual Pre-training of LLaMA-2 developed by Colossal-AI Team 
+Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """
 
-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm
 
 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 
 
 def get_model_numel(model: torch.nn.Module) -> int:
@@ -215,9 +208,18 @@ def main() -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
-    init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-    )
+
+    # ColossalAI changed the get_current_device API in version 0.3.4 and newer
+    try:
+        from colossalai.accelerator import get_accelerator
+
+        current_device = get_accelerator().get_current_device()
+    except:
+        from colossalai.utils import get_current_device
+
+        current_device = get_current_device()
+
+    init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     with init_ctx:
         model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
         # Freeze part of parameters.
@@ -320,7 +322,7 @@ def main() -> None:
             initial=start_step,
         ) as pbar:
             for step, batch in pbar:
-                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+                batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
 
                 batch_output = model(**batch)
 
@@ -372,9 +374,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
 
     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
 
diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py
index d144235d3..1405133af 100644
--- a/colossalai/accelerator/__init__.py
+++ b/colossalai/accelerator/__init__.py
@@ -1,5 +1,6 @@
 from .api import auto_set_accelerator, get_accelerator, set_accelerator
 from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
 from .cuda_accelerator import CudaAccelerator
 from .npu_accelerator import NpuAccelerator
 
@@ -10,4 +11,5 @@ __all__ = [
     "BaseAccelerator",
     "CudaAccelerator",
     "NpuAccelerator",
+    "CpuAccelerator",
 ]
diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py
index 393340b71..02b3055d7 100644
--- a/colossalai/accelerator/api.py
+++ b/colossalai/accelerator/api.py
@@ -3,6 +3,7 @@ from collections import OrderedDict
 from typing import Union
 
 from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
 from .cuda_accelerator import CudaAccelerator
 from .npu_accelerator import NpuAccelerator
 
@@ -15,7 +16,7 @@ _ACCELERATOR = None
 # we use ordered dictionary here to associate the
 # order with device check priority
 # i.e. auto_set_accelerator will check cuda first
-_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
+_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)
 
 
 def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
@@ -43,19 +44,17 @@ def auto_set_accelerator() -> None:
     """
     global _ACCELERATOR
 
-    for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
+    for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
         try:
             accelerator = accelerator_cls()
-            if accelerator.is_available():
+            if accelerator_name == "cpu" or accelerator.is_available():
                 _ACCELERATOR = accelerator
-            break
+                break
         except:
             pass
 
     if _ACCELERATOR is None:
-        raise RuntimeError(
-            f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
-        )
+        raise RuntimeError("No accelerator is available.")
 
 
 def get_accelerator() -> BaseAccelerator:
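
auto_set_accelerator walks the ordered mapping above, so detection now ends at the CPU backend instead of raising when no GPU or NPU is present. Illustrative usage of the three entry points (not part of the patch; set_accelerator accepts either a name string or a BaseAccelerator instance, per the signature above):

    from colossalai.accelerator import auto_set_accelerator, get_accelerator, set_accelerator

    set_accelerator("cpu")                         # pin a backend explicitly by name
    print(get_accelerator().get_current_device())  # -> cpu
    auto_set_accelerator()                         # or probe cuda -> npu -> cpu in priority order
    print(get_accelerator().get_current_device())  # e.g. cuda:0 on a GPU machine
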
diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py
index 71d03b8d6..a550cd7a2 100644
--- a/colossalai/accelerator/base_accelerator.py
+++ b/colossalai/accelerator/base_accelerator.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
+
 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -8,6 +9,8 @@ __all__ = ["BaseAccelerator"]
 
 
 class BaseAccelerator(ABC):
+    support_set_device: bool = True
+
     def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
         self._name = name
         self._communication_backend = communication_backend
@@ -45,6 +48,12 @@ class BaseAccelerator(ABC):
     # =======================
     # device APIs
     # =======================
+    @abstractmethod
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+
     @abstractmethod
     def current_device(self) -> int:
         """
@@ -52,7 +61,7 @@ class BaseAccelerator(ABC):
         """
 
     @abstractmethod
-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
@@ -79,3 +88,226 @@ class BaseAccelerator(ABC):
         """
         Return the number of devices on the machine.
         """
+
+    def set_to_device(self, models: Any) -> Any:
+        """
+        Send model to device.
+
+        :param models: nn.Module or a list of nn.Module
+        """
+        if isinstance(models, list) and len(models) > 1:
+            ret = []
+            for model in models:
+                ret.append(model.to(self.get_current_device()))
+            return ret
+        elif isinstance(models, list):
+            return models[0].to(self.get_current_device())
+        else:
+            return models.to(self.get_current_device())
+
+    @abstractmethod
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the capability of a device.
+        """
+
+    @abstractmethod
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+
+    @abstractmethod
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+
+    @abstractmethod
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels were executing on the device, as reported by nvidia-smi, npu-smi, etc.
+        """
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    @abstractmethod
+    def get_rng_state(self, device="cuda") -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified device as a ByteTensor.
+        """
+
+    @abstractmethod
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+
+    @abstractmethod
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
+        """
+        Sets the random number generator state of the specified device.
+        """
+
+    @abstractmethod
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+
+    @abstractmethod
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current device.
+        """
+
+    @abstractmethod
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers on all devices.
+        """
+
+    @abstractmethod
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current device.
+        """
+
+    @abstractmethod
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all devices.
+        """
+
+    @abstractmethod
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current device.
+        """
+
+    # =======================
+    # memory management APIs
+    # =======================
+    @abstractmethod
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
+        """
+
+    @abstractmethod
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of CUDA memory allocator statistics for a given device.
+        """
+
+    @abstractmethod
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+
+    @abstractmethod
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the CUDA memory allocator state across all devices.
+        """
+
+    @abstractmethod
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current device memory occupied by tensors in bytes for a given device.
+        """
+
+    @abstractmethod
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum device memory occupied by tensors in bytes for a given device.
+        """
+
+    @abstractmethod
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
+        """
+
+    @abstractmethod
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
+        """
+
+    @abstractmethod
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current device memory managed by the caching allocator in bytes for a given device.
+        """
+
+    @abstractmethod
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum device memory managed by the caching allocator in bytes for a given device.
+        """
+
+    @abstractmethod
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+
+    @abstractmethod
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the device memory allocator.
+        """
+
+    # =======================
+    # streams and events APIs
+    # =======================
+
+    @abstractmethod
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+        """
+
+    @abstractmethod
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+        """
+
+    @abstractmethod
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+
+    @abstractmethod
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+
+    @abstractmethod
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+
+    @abstractmethod
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+
+    # =======================
+    # amp APIs
+    # =======================
+    @abstractmethod
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """
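
With the expanded BaseAccelerator surface above (device, RNG, memory, stream and AMP hooks), downstream code can stay backend-agnostic by routing everything through get_accelerator(). A minimal sketch, assuming a CUDA or NPU device is present (several of these calls deliberately raise RuntimeError on the CPU accelerator):

    import torch
    from colossalai.accelerator import get_accelerator

    acc = get_accelerator()
    acc.manual_seed_all(42)                                   # RNG API
    x = torch.randn(1024, 1024, device=acc.get_current_device())
    y = x @ x
    acc.synchronize()                                         # device API
    print(acc.memory_allocated(), "bytes allocated")          # memory API
    acc.empty_cache()
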
diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py
new file mode 100644
index 000000000..c1f01b4f7
--- /dev/null
+++ b/colossalai/accelerator/cpu_accelerator.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python
+
+import resource
+from contextlib import nullcontext
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import psutil
+import torch
+
+from .base_accelerator import BaseAccelerator
+
+__all__ = ["CpuAccelerator"]
+
+
+class CpuAccelerator(BaseAccelerator):
+    """
+    Accelerator class for cpu.
+    """
+    support_set_device: bool = False
+
+    def __init__(self):
+        super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
+
+    # =======================
+    # device APIs
+    # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device("cpu")
+
+    def current_device(self) -> int:
+        """
+        Return the current device index.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
+        """
+        Bind the current process to a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_name(self, device: Union[torch.device, int]) -> str:
+        """
+        Return the name of the device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def synchronize(self, device: Union[torch.device, int] = None):
+        """
+        Synchronize the current process.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def is_available(self):
+        """
+        Check if the accelerator is available.
+        """
+        return True
+
+    def device_count(self):
+        """
+        Return the number of devices on the machine.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the cuda capability of a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels were executing on the GPU, as reported by nvidia-smi.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    def get_rng_state(self, device=None) -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified GPU as a ByteTensor.
+        """
+        return torch.get_rng_state(device)
+
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
+        """
+        Sets the random number generator state of the specified GPU.
+        """
+        torch.set_rng_state(new_state)
+
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current GPU.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Set the random seed for all processes.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current GPU.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all GPUs.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current GPU.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # memory management APIs
+    # =======================
+
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of CUDA memory allocator statistics for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the CUDA memory allocator state across all devices.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current GPU memory occupied by tensors in bytes for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return psutil.Process().memory_info().rss
+
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+        max_memory = int(psutil.virtual_memory().total * fraction)
+        _, hard = resource.getrlimit(resource.RLIMIT_AS)
+        resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
+
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the CUDA memory allocator.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # streams and events APIs
+    # =======================
+
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # amp APIs
+    # =======================
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """
+        return nullcontext
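
Unlike the CUDA and NPU backends, the CPU accelerator answers memory queries from psutil and the resource module, and set_per_process_memory_fraction caps the address space with RLIMIT_AS. An illustrative check (not part of the patch; setrlimit behaves this way on Linux, and note that ru_maxrss is reported in kilobytes there even though the docstring above says bytes):

    from colossalai.accelerator import CpuAccelerator

    cpu_acc = CpuAccelerator()
    print(cpu_acc.get_current_device())           # cpu
    print(cpu_acc.memory_reserved())              # current RSS in bytes, via psutil
    print(cpu_acc.max_memory_reserved())          # peak RSS from getrusage (KiB on Linux)
    cpu_acc.set_per_process_memory_fraction(0.5)  # limit address space to half of total RAM
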
diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py
index 72152834a..bdaf53bd5 100644
--- a/colossalai/accelerator/cuda_accelerator.py
+++ b/colossalai/accelerator/cuda_accelerator.py
@@ -1,7 +1,9 @@
 #!/usr/bin/env python
-from typing import Union
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 
 from .base_accelerator import BaseAccelerator
 
@@ -19,16 +21,26 @@ class CudaAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device(f"cuda:{torch.cuda.current_device()}")
+
     def current_device(self) -> int:
         """
         Return the current device index.
         """
         return torch.cuda.current_device()
 
-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
+        if device is None:
+            if not dist.is_initialized():
+                raise RuntimeError("Cannot get current device when distributed is not initialized.")
+            device = dist.get_rank() % self.device_count()
         torch.cuda.set_device(device)
 
     def get_device_name(self, device: Union[torch.device, int]) -> str:
@@ -54,3 +66,211 @@ class CudaAccelerator(BaseAccelerator):
         Return the number of devices on the machine.
         """
         return torch.cuda.device_count()
+
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the cuda capability of a device.
+        """
+        return torch.cuda.get_device_capability(device)
+
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+        return torch.cuda.get_device_name(device)
+
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+        return torch.cuda.get_device_properties(device)
+
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels were executing on the GPU, as reported by nvidia-smi.
+        """
+        return torch.cuda.utilization(device)
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    def get_rng_state(self, device="cuda") -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified GPU as a ByteTensor.
+        """
+        return torch.cuda.get_rng_state(device)
+
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+        return torch.cuda.get_rng_state_all()
+
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
+        """
+        Sets the random number generator state of the specified GPU.
+        """
+        torch.cuda.set_rng_state(new_state, device)
+
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+        torch.cuda.set_rng_state_all(new_states)
+
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current GPU.
+        """
+        torch.cuda.manual_seed(seed)
+
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Set the random seed for all processes.
+        """
+        torch.cuda.manual_seed_all(seed)
+
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current GPU.
+        """
+        torch.cuda.seed()
+
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all GPUs.
+        """
+        torch.cuda.seed_all()
+
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current GPU.
+        """
+        return torch.cuda.initial_seed()
+
+    # =======================
+    # memory management APIs
+    # =======================
+
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+        """
+        torch.cuda.empty_cache()
+
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of CUDA memory allocator statistics for a given device.
+        """
+        return torch.cuda.memory_stats(device=device)
+
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+        return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
+
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the CUDA memory allocator state across all devices.
+        """
+        return torch.cuda.memory_snapshot()
+
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current GPU memory occupied by tensors in bytes for a given device.
+        """
+        return torch.cuda.memory_allocated(device=device)
+
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+        """
+        return torch.cuda.max_memory_allocated(device=device)
+
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+        """
+        torch.cuda.reset_max_memory_allocated(device=device)
+
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+        """
+        torch.cuda.reset_max_memory_cached(device=device)
+
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return torch.cuda.memory_reserved(device=device)
+
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return torch.cuda.max_memory_reserved(device=device)
+
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+        torch.cuda.set_per_process_memory_fraction(fraction, device=device)
+
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the CUDA memory allocator.
+        """
+        torch.cuda.reset_peak_memory_stats(device=device)
+
+    # =======================
+    # streams and events APIs
+    # =======================
+
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+        """
+        return torch.cuda.Stream(device, priority, **kwargs)
+
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+        """
+        return torch.cuda.Event(enable_timing, blocking, interprocess)
+
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+        return torch.cuda.current_stream(device)
+
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+        return torch.cuda.default_stream(device)
+
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+        torch.cuda.set_stream(stream_)
+
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+        return torch.cuda.stream(stream_)
+
+    # =======================
+    # amp APIs
+    # =======================
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """
+        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
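
set_device can now be called with no argument: when torch.distributed is already initialized, the CUDA accelerator binds the process to rank % device_count(). An illustrative per-process setup (not part of the patch; assumes a torchrun-style launch that provides the usual rendezvous environment variables):

    import torch
    import torch.distributed as dist
    from colossalai.accelerator import get_accelerator

    dist.init_process_group(backend="nccl")
    acc = get_accelerator()
    acc.set_device()                              # binds this process to rank % device_count()
    x = torch.randn(2, 2, device=acc.get_current_device())
    with acc.autocast(dtype=torch.float16):       # thin wrapper over torch.cuda.amp.autocast
        y = x @ x
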
diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py
index a8bba6eaf..b3575dbfe 100644
--- a/colossalai/accelerator/npu_accelerator.py
+++ b/colossalai/accelerator/npu_accelerator.py
@@ -1,13 +1,17 @@
 #!/usr/bin/env python
 
-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 
 from .base_accelerator import BaseAccelerator
 
+IS_NPU_AVAILABLE = False
 try:
     import torch_npu  # noqa
+
+    IS_NPU_AVAILABLE = True
 except ImportError:
     pass
 
@@ -26,16 +30,26 @@ class NpuAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device(f"npu:{torch.npu.current_device()}")
+
     def current_device(self) -> int:
         """
         Return the current device index.
         """
         return torch.npu.current_device()
 
-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
+        if device is None:
+            if not dist.is_initialized():
+                raise RuntimeError("Cannot get current device when distributed is not initialized.")
+            device = dist.get_rank() % self.device_count()
         torch.npu.set_device(device)
 
     def get_device_name(self, device: Union[torch.device, int]) -> str:
@@ -61,3 +75,211 @@ class NpuAccelerator(BaseAccelerator):
         Return the number of devices on the machine.
         """
         return torch.npu.device_count()
+
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the npu capability of a device.
+        """
+        return torch.npu.get_device_capability(device)
+
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+        return torch.npu.get_device_name(device)
+
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+        return torch.npu.get_device_properties(device)
+
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels were executing on the NPU, as reported by npu-smi.
+        """
+        return torch.npu.utilization(device)
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    def get_rng_state(self, device="npu") -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified GPU as a ByteTensor.
+        """
+        return torch.npu.get_rng_state(device)
+
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+        return torch.npu.get_rng_state_all()
+
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
+        """
+        Sets the random number generator state of the specified GPU.
+        """
+        torch.npu.set_rng_state(new_state, device)
+
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+        torch.npu.set_rng_state_all(new_states)
+
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current GPU.
+        """
+        torch.npu.manual_seed(seed)
+
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Set the random seed for all processes.
+        """
+        torch.npu.manual_seed_all(seed)
+
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current GPU.
+        """
+        torch.npu.seed()
+
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all GPUs.
+        """
+        torch.npu.seed_all()
+
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current GPU.
+        """
+        return torch.npu.initial_seed()
+
+    # =======================
+    # memory management APIs
+    # =======================
+
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+        """
+        torch.npu.empty_cache()
+
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of npu memory allocator statistics for a given device.
+        """
+        return torch.npu.memory_stats(device=device)
+
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+        return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
+
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the npu memory allocator state across all devices.
+        """
+        return torch.npu.memory_snapshot()
+
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current GPU memory occupied by tensors in bytes for a given device.
+        """
+        return torch.npu.memory_allocated(device=device)
+
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+        """
+        return torch.npu.max_memory_allocated(device=device)
+
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+        """
+        torch.npu.reset_max_memory_allocated(device=device)
+
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+        """
+        torch.npu.reset_max_memory_cached(device=device)
+
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return torch.npu.memory_reserved(device=device)
+
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return torch.npu.max_memory_reserved(device=device)
+
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+        torch.npu.set_per_process_memory_fraction(fraction, device=device)
+
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the npu memory allocator.
+        """
+        torch.npu.reset_peak_memory_stats(device=device)
+
+    # =======================
+    # streams and events APIs
+    # =======================
+
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details.
+        """
+        return torch.npu.Stream(device, priority, **kwargs)
+
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams.
+        """
+        return torch.npu.Event(enable_timing, blocking, interprocess)
+
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+        return torch.npu.current_stream(device)
+
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+        return torch.npu.default_stream(device)
+
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+        torch.npu.set_stream(stream_)
+
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+        return torch.npu.stream(stream_)
+
+    # =======================
+    # amp APIs
+    # =======================
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """
+        return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
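
The NPU class mirrors the CUDA one call for call on top of torch.npu from the optional torch_npu extension; the module-level IS_NPU_AVAILABLE flag records whether that import succeeded. An illustrative probe (not part of the patch):

    from colossalai.accelerator import NpuAccelerator, get_accelerator
    from colossalai.accelerator.npu_accelerator import IS_NPU_AVAILABLE

    if IS_NPU_AVAILABLE:
        acc = NpuAccelerator()
        acc.set_device(0)
        print(acc.get_device_name(0), acc.device_count())
    else:
        # Falls back to whatever auto-detection picks (cuda or cpu).
        print("torch_npu not installed; using", get_accelerator().get_current_device())
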
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 439d13dcf..fc4c884d4 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -7,8 +7,8 @@ from typing import Dict
 import torch
 from torch import Tensor
 
+from colossalai.accelerator import get_accelerator
 from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device
 
 __all__ = ["BaseGradScaler"]
 
@@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
 
     def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
-        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
+        self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
         self._verbose = verbose
 
         if self._verbose:
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 86ba919ee..5cd8035d7 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .base_grad_scaler import BaseGradScaler
 
@@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
         hysteresis: int = 2,
         verbose: bool = False,
     ):
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
+            self._min_scale = torch.tensor(
+                [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._min_scale = None
 
         if max_scale:
-            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
+            self._max_scale = torch.tensor(
+                [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._max_scale = None
 
@@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
         return state_dict
 
     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].to(get_current_device())
+        self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
index 9ce272356..2e7c8a281 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -5,8 +5,8 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.utils import get_current_device
 
 from .base import MixedPrecisionMixin
 
@@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
             max_scale=max_scale,
         )
         self.optim_state = OptimState.UNSCALED
-        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
 
     @property
     def loss_scale(self) -> float:
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index 601bf2926..fe8439269 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -4,10 +4,10 @@ from typing import Dict, Tuple
 import torch
 from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 
 from .base_offload_module import BaseOffloadModule
 from .region import Region
@@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
             hysteresis=hysteresis,
             max_scale=max_scale,
         )
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        self._found_overflow: torch.Tensor = torch.zeros(
+            1, dtype=torch.int64, device=get_accelerator().get_current_device()
+        )
         self._logger = get_dist_logger()
 
     def _set_grad_ptr(self):
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index a6628e29c..3ad210de9 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -11,7 +11,7 @@ except:
 import torch
 from torch.fx.node import Node
 
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
@@ -57,7 +57,10 @@ class Solver(ABC):
         if memory_budget > 0:
             self.memory_budget = memory_budget * self.error_factor
         else:
-            self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
+            self.memory_budget = (
+                torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+                * self.error_factor
+            )
 
         self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
         self.comp_power: float = self._extract_computing_power()
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 443c4094c..c757a878d 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -5,8 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.utils.device import autocast
 
 from .mixed_precision_base import MixedPrecision
 
@@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)
 
     def forward(self, *args, **kwargs):
-        with autocast():
+        with get_accelerator().autocast():
             return self.module(*args, **kwargs)
 
 
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 261080dc9..d6610a3e1 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -12,6 +12,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -24,8 +25,6 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
@@ -367,7 +366,7 @@ class GeminiPlugin(DPPluginBase):
             assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
-            chunk_init_device=(chunk_init_device or get_current_device()),
+            chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
             placement_policy=placement_policy,
             enable_gradient_accumulation=enable_gradient_accumulation,
             shard_param_frac=shard_param_frac,
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index bbc36ceab..2cc9e19bf 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
@@ -29,7 +30,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
-from colossalai.utils.device import get_current_device
 
 from .pp_plugin_base import PipelinePluginBase
 
@@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.to(get_current_device())
+        module = module.to(get_accelerator().get_current_device())
 
         # setting input type cast when using mixed precision
         self.convert_fn = None
@@ -346,7 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:
@@ -385,7 +387,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -543,7 +547,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
 
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
 
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -586,7 +592,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -798,7 +806,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm 'tp' of 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)
 
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
 
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -837,7 +847,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if dp_size > 1:
                 # compute norm in dp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 89102820c..d21496f0b 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
@@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.utils import get_current_device
 from colossalai.zero import LowLevelZeroOptimizer
 
 from .dp_plugin_base import DPPluginBase
@@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             self.dtype = torch.bfloat16
         if self.dtype is not None:
             module = module.to(self.dtype)
-        module = module.to(get_current_device())
+        module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
         if self.dtype is not None:
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 25076b742..aaeaad382 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -6,12 +6,12 @@ import warnings
 from pathlib import Path
 from typing import Dict, Union
 
-import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
+from colossalai.utils import set_seed
 
 
 def launch(
@@ -47,17 +47,18 @@ def launch(
     if rank == 0:
         warnings.warn("`config` is deprecated and will be removed soon.")
 
-    if IS_NPU_AVAILABLE and backend == "nccl":
-        backend = "hccl"
+    cur_accelerator = get_accelerator()
+
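+    # each accelerator reports its native collective backend (e.g. "nccl" for CUDA,
+    # "hccl" for NPU), so the process group backend follows the detected device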
+    backend = cur_accelerator.communication_backend
 
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
     # set cuda device
-    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
-        # if local rank is not given, calculate automatically
-        set_device(local_rank)
+    # if local rank is not given, calculate automatically
+    if cur_accelerator.support_set_device:
+        cur_accelerator.set_device(local_rank)
 
     set_seed(seed)
 
diff --git a/colossalai/kernel/extensions/flash_attention/utils.py b/colossalai/kernel/extensions/flash_attention/utils.py
index 0eab9e89f..06fef491f 100644
--- a/colossalai/kernel/extensions/flash_attention/utils.py
+++ b/colossalai/kernel/extensions/flash_attention/utils.py
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange
 
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 
 
 class Unpad(torch.autograd.Function):
@@ -70,7 +70,9 @@ class SeqLenInfo:
     cu_seqlens: torch.Tensor = None
 
     @staticmethod
-    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
+    def materialize(
+        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
+    ):
         if attn_mask is not None:
             indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
             seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index 8bebad894..d392649a6 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,7 +1,7 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
-from colossalai.utils import get_current_device
 
 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
@@ -46,11 +46,13 @@ def warmup_jit_fusion(
 ):
     """Compile JIT functions before the main training steps"""
 
-    embed = Embedding(vocab_size, hidden_size).to(get_current_device())
-    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
-    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
+    embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
+    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
+    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
 
-    x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+    x = torch.randint(
+        vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
+    )
     x = embed(x)
     y, y_bias = linear_1(x)
     z, z_bias = linear_2(y)
@@ -58,8 +60,8 @@ def warmup_jit_fusion(
     # prop and recomputation
     for bias_grad, input_grad in zip([True, True], [False, True]):
         for _ in range(10):
-            bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
-            input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+            bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
+            input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
             bias.requires_grad, input_.requires_grad = bias_grad, input_grad
             bias_gelu_impl(input_, bias)
 
@@ -69,9 +71,9 @@ def warmup_jit_fusion(
     # prop and recomputation
     for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
         for _ in range(10):
-            input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
-            residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
-            bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+            input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
+            residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
+            bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
             input_.requires_grad = input_grad
             bias.requires_grad = bias_grad
             residual.requires_grad = residual_grad
diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py
index 0a8d09be2..08f867eee 100644
--- a/colossalai/legacy/amp/torch_amp/torch_amp.py
+++ b/colossalai/legacy/amp/torch_amp/torch_amp.py
@@ -1,18 +1,19 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from colossalai.utils.device import autocast
-
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.legacy.utils import clip_grad_norm_fp32
 
 from ._grad_scaler import GradScaler
 
+autocast = get_accelerator().autocast
+
 
 class TorchAMPOptimizer(OptimizerWrapper):
     """A wrapper class which integrate Pytorch AMP with an optimizer
diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py
index 19c3919b6..cf0bd4ba2 100644
--- a/colossalai/legacy/communication/p2p.py
+++ b/colossalai/legacy/communication/p2p.py
@@ -8,9 +8,9 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 
 from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
 
@@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
 def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
     if isinstance(recv_shapes, torch.Size):
         recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
-        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        buffer_recv = torch.empty(
+            recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
+        )
         return buffer_recv, recv_split
     buffer_recv = []
     for recv_shape in recv_shapes:
         recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
-        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        tensor_recv = torch.empty(
+            recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
+        )
         buffer_recv.append(tensor_recv)
     return buffer_recv, recv_split
 
diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py
index a61dae56c..792a15abd 100644
--- a/colossalai/legacy/communication/ring.py
+++ b/colossalai/legacy/communication/ring.py
@@ -3,9 +3,9 @@
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device, synchronize
 
 
 def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
@@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
     current_rank = gpc.get_global_rank()
 
     tensor_recv_prev = torch.empty(
-        buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
+        buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype
     )
 
     # send to next rank
@@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
         req.wait()
 
     # To protect against race condition when using batch_isend_irecv().
-    synchronize()
+    get_accelerator().synchronize()
 
     return tensor_recv_prev
diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py
index 6d77f3753..0b7c0eb74 100644
--- a/colossalai/legacy/communication/utils.py
+++ b/colossalai/legacy/communication/utils.py
@@ -3,9 +3,9 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
 
@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
         if next_rank is None:
             next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
 
-        tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
+        tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
         if isinstance(obj, torch.Tensor):
             send_obj_nums = torch.tensor(1, **tensor_kwargs)
             dist.send(send_obj_nums, next_rank)
@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
 
-        tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
+        tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
         recv_obj_nums = torch.empty((), **tensor_kwargs)
         dist.recv(recv_obj_nums, prev_rank)
         if recv_obj_nums.item() == 1:
diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py
index 4a3ccfda1..9b2913442 100644
--- a/colossalai/legacy/engine/schedule/_base_schedule.py
+++ b/colossalai/legacy/engine/schedule/_base_schedule.py
@@ -6,8 +6,8 @@ from typing import Callable, Iterable
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 
 
 class BaseSchedule(ABC):
@@ -29,12 +29,12 @@ class BaseSchedule(ABC):
     def _move_tensor(element):
         if torch.is_tensor(element):
             if not element.is_cuda:
-                return element.to(get_current_device()).detach()
+                return element.to(get_accelerator().get_current_device()).detach()
         return element
 
     def _move_to_device(self, data):
         if isinstance(data, torch.Tensor):
-            data = data.to(get_current_device())
+            data = data.to(get_accelerator().get_current_device())
         elif isinstance(data, (list, tuple)):
             data_to_return = []
             for element in data:
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 5fd5602e7..4a23853c1 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
 import torch.cuda
 
 import colossalai.legacy.communication as comm
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp.naive_amp import NaiveAMPModel
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device
 
 from ._base_schedule import BaseSchedule
 
@@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
             output_objs = []
         return_tensors = []
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
         # Used for tensor meta information communication
@@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if not forward_only:
             output_obj_grads = [[] for _ in range(len(model))]
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
 
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 4cd7e47c3..6e7760218 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -6,10 +6,10 @@ from typing import Iterable, Tuple
 import torch.cuda
 
 import colossalai.legacy.communication.p2p_v2 as comm
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.engine import Engine
-from colossalai.utils.device import get_current_device
 
 from ._pipeline_schedule import PipelineSchedule
 
@@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
             output_objs = []
         return_tensors = []
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
 
diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py
index 4035bd6b5..d99a7d3f0 100644
--- a/colossalai/legacy/initialize.py
+++ b/colossalai/legacy/initialize.py
@@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
+from colossalai.accelerator import get_accelerator
 from colossalai.context import Config, ConfigException
 from colossalai.interface import OptimizerWrapper
 from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
@@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
 from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 
 
 def get_default_parser():
@@ -309,9 +309,9 @@ def initialize(
     else:
         if isinstance(model, nn.Module):
             # first sync model across dp ranks
-            model.to(get_current_device())
+            model.to(get_accelerator().get_current_device())
         elif isinstance(model, Callable):
-            model = model().to(get_current_device())
+            model = model().to(get_accelerator().get_current_device())
 
         # optimizer maybe a optimizer_cls
         if isinstance(optimizer, Callable):
diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
index e1db0fe98..aa661664f 100644
--- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
@@ -3,8 +3,8 @@ from typing import Callable
 
 from torch import dtype, nn
 
+from colossalai.accelerator import get_accelerator
 from colossalai.nn import init
-from colossalai.utils import get_current_device
 
 from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
 from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
@@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
             embed = (
                 nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs)
                 .to(dtype)
-                .to(get_current_device())
+                .to(get_accelerator().get_current_device())
             )
             weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
         elif num_embeddings <= vocab_parallel_limit:
diff --git a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
index f8e317e72..58842f481 100644
--- a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
@@ -1,6 +1,6 @@
 from torch import nn
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from ..parallel_1d import LayerNorm1D
 from ..parallel_2d import LayerNorm2D
@@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
     def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
-            norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
+            norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device())
         else:
             norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
         super().__init__(norm)
diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py
index b6ec5347f..36cb09d32 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.kernel import LayerNorm
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
@@ -22,7 +23,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
@@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
 
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
@@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
 
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
@@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
 
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
 
         if bias:
@@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
 
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
 
         if self.stream_chunk_num > 1:
@@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
         self.embed_kwargs = kwargs
 
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
 
         self.reset_parameters(weight_initializer)
@@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (self.num_embeddings_per_partition, self.embed_dim),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
+            )
         )
 
         self.reset_parameters(weight_initializer)
diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
index f1eff7128..f67ee2e60 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
@@ -5,10 +5,10 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 
 
 def matmul_2d(
@@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
-        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
 
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[0])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
 
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[-1], B.shape[-1])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
 
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py
index f81c5334a..4987afa18 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
@@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
         self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
 
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(
             torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
         )
@@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
         self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
 
         # create parameters
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
         if bias:
@@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
             self.weight = Parameter(
                 torch.empty(
                     (self.embed_size_per_partition, in_chans, *self.patch_size),
-                    device=get_current_device(),
+                    device=get_accelerator().get_current_device(),
                     dtype=dtype,
                 )
             )
-            self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
 
             self.cls_token = Parameter(
-                torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
             self.pos_embed = Parameter(
                 torch.zeros(
-                    (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+                    (1, self.num_patches + 1, self.embed_size_per_partition),
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
                 )
             )
 
@@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
         self.embed_kwargs = kwargs
 
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
 
         self.reset_parameters(weight_initializer)
@@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
@@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes,
+                    self.input_size_per_partition,
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
+                )
             )
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
 
@@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
         self.output_size_per_partition = divide(num_classes, self.summa_dim)
 
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
index 50900c135..43328bd03 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
@@ -5,10 +5,10 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 
 
 def get_parallel_group(parallel_mode: ParallelMode):
@@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
-        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
 
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[0])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
 
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[-1], B.shape[-1])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
 
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
         if row_rank == 0:
             bias_temp = bias.clone()
         else:
-            bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
+            bias_temp = torch.zeros(
+                output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device()
+            )
         src_rank = (
             col_rank
             + dep_rank * tesseract_dim**2
@@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
     @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
-        grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
+        grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device())
         dist.all_gather(
             list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
         )
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
index b451a4031..d9410f1cb 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
@@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
         self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
 
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(
             torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
         )
@@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
         self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)  # *
 
         # create parameters
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
         if bias:
@@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
             self.weight = Parameter(
                 torch.empty(
                     (self.embed_size_per_partition, in_chans, *self.patch_size),
-                    device=get_current_device(),
+                    device=get_accelerator().get_current_device(),
                     dtype=dtype,
                 )
             )
-            self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
 
             self.cls_token = Parameter(
-                torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
             self.pos_embed = Parameter(
                 torch.zeros(
-                    (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+                    (1, self.num_patches + 1, self.embed_size_per_partition),
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
                 )
             )
 
@@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
         self.embed_kwargs = kwargs
 
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
 
         self.reset_parameters(weight_initializer)
@@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
@@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes,
+                    self.input_size_per_partition,
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
+                )
             )
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
 
@@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
         self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
 
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py
index 16e515f87..bb01ec851 100644
--- a/colossalai/legacy/nn/layer/parallel_3d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import all_reduce, broadcast
 from colossalai.legacy.constants import (
     INPUT_GROUP_3D,
@@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (
@@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
         self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
 
         self.weight = Parameter(
-            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+            torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
         )
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
         else:
             self.bias = None
@@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
             torch.empty(
                 self.in_features_per_partition,
                 self.out_features_per_partition,
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
         else:
             self.bias = None
@@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes,
+                    self.in_features_per_partition,
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
+                )
             )
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
 
@@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
                 torch.empty(
                     self.out_features_per_partition,
                     self.in_features_per_partition,
-                    device=get_current_device(),
+                    device=get_accelerator().get_current_device(),
                     dtype=dtype,
                 )
             )
             self.has_weight = True
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
         else:
             self.bias = None
@@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
 
         self.weight = nn.Parameter(
             torch.empty(
-                (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
+                (embed_size_per_partition, in_chans, *self.patch_size),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
             )
         )
-        self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
+        self.bias = nn.Parameter(
+            torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+        )
 
         self.cls_token = nn.Parameter(
-            torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype)
         )
         self.pos_embed = nn.Parameter(
-            torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            torch.zeros(
+                (1, self.num_patches + 1, embed_size_per_partition),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
+            )
         )
 
         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
         self.embed_kwargs = kwargs
 
         self.weight = nn.Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
 
         self.reset_parameters(weight_initializer)
@@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
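
The 3D-parallel layer hunks above all apply the same substitution: parameters that were allocated on get_current_device() are now placed on get_accelerator().get_current_device(), which resolves to CUDA, NPU, or CPU depending on the detected backend. A minimal sketch of the resulting allocation pattern, with an illustrative partition size instead of the real divide(...) result:

    import torch
    from torch.nn import Parameter

    from colossalai.accelerator import get_accelerator

    # Illustrative partition size; the real layers compute this via divide(normalized_shape, depth).
    normalized_shape_per_partition = 256

    # Allocate the weight directly on the active accelerator instead of assuming torch.cuda.
    weight = Parameter(
        torch.ones(
            normalized_shape_per_partition,
            device=get_accelerator().get_current_device(),
            dtype=torch.float32,
        )
    )
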
diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
index 24d5499e3..4e9bf364d 100644
--- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
@@ -5,11 +5,11 @@ import torch
 from torch import distributed as dist
 from torch.cuda.amp import custom_bwd, custom_fwd
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import ring_forward
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
-from colossalai.utils import get_current_device
 
 
 class RingQK(torch.autograd.Function):
@@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function):
             sub_seq_length,
             sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
             dtype=sub_q.dtype,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
         )
 
         # compute local QK^T
@@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function):
         grad_q = torch.zeros_like(
             sub_q,
             dtype=sub_q.dtype,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
         )
 
         # compute with local sub_k
@@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function):
             batch_size * num_attention_heads,
             sub_seq_length,
             attention_head_size,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
             dtype=attention_score.dtype,
         )
 
@@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function):
         grad_v /= local_world_size
 
         # calculate gradient for attention score
-        grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
+        grad_attention_score = torch.zeros_like(
+            attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device()
+        )
 
         # compute with local sub_k
         grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py
index 590ad5ff6..3a1c2e57b 100644
--- a/colossalai/legacy/nn/layer/vanilla/layers.py
+++ b/colossalai/legacy/nn/layer/vanilla/layers.py
@@ -7,10 +7,10 @@ from torch import Tensor
 from torch import nn as nn
 from torch.nn.parameter import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import seed
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 
 from ..utils import to_2tuple
 
@@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module):
         self.flatten = flatten
 
         self.weight = nn.Parameter(
-            torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype
+            )
+        )
+        self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype))
+        self.cls_token = nn.Parameter(
+            torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype)
         )
-        self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
-        self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
         self.pos_embed = nn.Parameter(
-            torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)
+            torch.zeros(
+                (1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
 
         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module):
             self.has_weight = False
         else:
             self.weight = nn.Parameter(
-                torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
             self.has_weight = True
         if bias:
-            self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = nn.Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
 
@@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module):
         self.normalized_shape = (normalized_shape,)
         self.variance_epsilon = eps
 
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
 
         self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
         if bias:
@@ -333,7 +343,7 @@ class VanillaLinear(nn.Module):
         self.in_features = in_features
         self.out_features = out_features
         self.skip_bias_add = skip_bias_add
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
         if bias:
             self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py
index 44f39a6db..474fd4a2c 100644
--- a/colossalai/legacy/nn/loss/loss_2d.py
+++ b/colossalai/legacy/nn/loss/loss_2d.py
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device
 
 
 @LOSSES.register_module
@@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
         grad_2d = grad_input.view(-1, partition_vocab_size)
 
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
 
         # Finally elementwise multiplication with the output gradients.
diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py
index c57bf26e9..b423ab3d8 100644
--- a/colossalai/legacy/nn/loss/loss_2p5d.py
+++ b/colossalai/legacy/nn/loss/loss_2p5d.py
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device
 
 
 @LOSSES.register_module
@@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
         grad_2d = grad_input.view(-1, partition_vocab_size)
 
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
 
         # Finally elementwise multiplication with the output gradients.
diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py
index 988317cae..de6a674d6 100644
--- a/colossalai/legacy/nn/loss/loss_3d.py
+++ b/colossalai/legacy/nn/loss/loss_3d.py
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
 from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device
 
 
 @LOSSES.register_module
@@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
         target_mask = (targets < vocab_start) | (targets > vocab_end)
         masked_target = targets.clone() - vocab_start
         masked_target[target_mask] = 0
-        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
         predicted_logits = logits[arange_1d, masked_target]
         predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
         predicted_logits[target_mask] = 0.0
@@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
         grad_2d = input_grad.view(-1, partition_vocab_size)
 
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
         input_grad.mul_(output_grad.unsqueeze(dim=-1))
 
diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py
index 35a7f0a15..0e6731db5 100644
--- a/colossalai/legacy/trainer/hooks/_metric_hook.py
+++ b/colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -7,12 +7,12 @@ from typing import Callable
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import all_reduce
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.utils import is_no_pp_or_last_stage
-from colossalai.utils import get_current_device
 
 from ._base_hook import BaseHook
 from ._commons_ import _format_number
@@ -82,8 +82,8 @@ class LossMetric(Metric):
 
     def __init__(self, epoch_only):
         super().__init__(epoch_only=epoch_only)
-        self.last_step_loss = torch.zeros(1, device=get_current_device())
-        self.accum_loss = torch.zeros(1, device=get_current_device())
+        self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         self.count = 0
 
     def reset(self) -> None:
@@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
     def __init__(self, epoch_only: bool, accuracy_func: Callable):
         super().__init__(epoch_only=epoch_only)
         self.acc = accuracy_func
-        self.last_step_sum = torch.zeros(1, device=get_current_device())
-        self.last_step_correct = torch.zeros(1, device=get_current_device())
-        self.accumulated_sum = torch.zeros(1, device=get_current_device())
-        self.accumulated_correct = torch.zeros(1, device=get_current_device())
+        self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
 
     def reset(self) -> None:
         self.last_step_sum.zero_()
@@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0
-        self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
-        self.accumulated_used_time = torch.zeros(1, device=get_current_device())
-        self.last_step_num_samples = torch.zeros(1, device=get_current_device())
-        self.last_step_used_time = torch.zeros(1, device=get_current_device())
+        self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
         self._tflop_per_step = tflop_per_step
         self._use_local = use_local
 
diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py
index 9a8051ae9..d1382cb1e 100644
--- a/colossalai/legacy/utils/activation_checkpoint.py
+++ b/colossalai/legacy/utils/activation_checkpoint.py
@@ -6,8 +6,8 @@ import weakref
 import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils.device import autocast, get_current_device
 
 
 def copy_to_device(obj, device):
@@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
         check_backward_validity(args)
         ctx.run_function = run_function
         ctx.activation_offload = activation_offload
-        ctx.device = get_current_device()
+        ctx.device = get_accelerator().get_current_device()
 
         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
             inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), autocast():
+            with torch.enable_grad(), get_accelerator().autocast()():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():
@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
 
             # rerun forward, the inner_pack will store all the activations in storage
             if has_autocast_in_fwd:
-                with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
+                with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
                     inner_pack, inner_unpack
                 ):
                     _unused = function(*args)
@@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
 
     # get device if we need to offload the activation
     if activation_offload:
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
 
     # run function with pack and unpack as saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
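
In the checkpointing code above, the forward pass records the accelerator's current device so that offloaded activations can be moved back before the backward pass. A hedged sketch of that round trip with a toy tensor standing in for a saved activation:

    import torch

    from colossalai.accelerator import get_accelerator

    # Device recorded at forward time (ctx.device in the hunks above).
    device = get_accelerator().get_current_device()

    # Offload a toy "activation" to CPU, then bring it back for the backward pass.
    activation = torch.randn(4, 4, device=device)
    offloaded = activation.to("cpu")
    restored = offloaded.to(device)
    assert restored.device == activation.device
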
diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py
index 2f99a7d2f..cfb22d315 100644
--- a/colossalai/legacy/utils/memory.py
+++ b/colossalai/legacy/utils/memory.py
@@ -6,9 +6,9 @@ import torch
 import torch.distributed as dist
 from packaging import version
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 _GLOBAL_CPU_MEM_CAPACITY = -1
@@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
         # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
         return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
     if device.type == "cuda":
-        return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
+        return (
+            torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+            * _GLOBAL_CUDA_MEM_FRACTION
+        )
 
 
 def colo_device_memory_used(device: torch.device) -> int:
@@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
         return
     global _GLOBAL_CUDA_MEM_FRACTION
     _GLOBAL_CUDA_MEM_FRACTION = ratio
-    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
+    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())
 
 
 def colo_set_cpu_memory_capacity(size: int) -> None:
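
The memory utilities keep their CUDA-specific calls but now obtain the device through the accelerator. A small sketch of capping per-process CUDA memory, guarded so it is a no-op without CUDA; the 0.8 ratio is an arbitrary example:

    import torch

    from colossalai.accelerator import get_accelerator

    if torch.cuda.is_available():
        device = get_accelerator().get_current_device()
        total = torch.cuda.get_device_properties(device).total_memory
        fraction = 0.8  # arbitrary example ratio
        torch.cuda.set_per_process_memory_fraction(fraction, device)
        print(f"capacity available to this process: {total * fraction / 1e9:.2f} GB")
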
diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py
index ad54b989f..a9e3ffe1a 100644
--- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py
+++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py
@@ -8,7 +8,7 @@ import torch.distributed as dist
 from torch.autograd.profiler import profile
 from torch.distributed import ReduceOp
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
 
@@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):
 
             assert current_comm_event is not None, "dist op has not been found"
 
-            buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
+            buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
             torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
             current_comm_event.self_cuda_time = buffer.item()
 
diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
index e336717f4..b0360880e 100644
--- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
+++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
@@ -3,7 +3,7 @@ import types
 from time import time
 from typing import List
 
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy
@@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
         # move COMPUTE tensors to CUDA
         self._cpu_gpu_move_volume += cuda_demand
         for t in move_to_cuda_tensor_list:
-            colo_model_data_tensor_move_inline(t, get_current_device())
+            colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())
 
     @property
     def cpu_gpu_move_volume(self):
diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py
index 3aca80cfe..6fde91d4a 100644
--- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py
+++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py
@@ -5,8 +5,8 @@ from typing import List, Optional, Type
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
 
 from .stateful_tensor import StatefulTensor
@@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
-        super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
+        super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)
 
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
         return 0, 0
@@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             int: the volume of memory that is evicted
         """
         start = time()
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
index b9d3071a8..e5a35dea1 100644
--- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -4,8 +4,8 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors as flatten
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils import get_current_device
 
 from .tensor_shard_strategy import TensorShardStrategy
 
@@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         rank = dist.get_rank(process_group)
         for i in range(world_size):
             if i == rank:
-                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
+                buffer_list.append(
+                    flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
+                )
             else:
-                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
+                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
         # Move to target device before splitting buffer
         # Ensure we utilize maximum PCIE bandwidth
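
The bucketed shard strategy builds its all-gather buffers on the accelerator's current device. A dist-free sketch of the buffer layout; world size, rank, and payload sizes are illustrative, and the real code feeds these buffers to dist.all_gather:

    import torch
    from torch._utils import _flatten_dense_tensors as flatten

    from colossalai.accelerator import get_accelerator

    world_size, rank = 2, 0  # illustrative; normally taken from the process group
    payloads = [torch.randn(8), torch.randn(8)]
    buffer_size = sum(p.numel() for p in payloads)

    device = get_accelerator().get_current_device()
    buffer_list = []
    for i in range(world_size):
        if i == rank:
            # This rank contributes its own flattened payloads.
            buffer_list.append(flatten(payloads).to(device))
        else:
            # Placeholders to be filled in by the gather from other ranks.
            buffer_list.append(torch.zeros(buffer_size, dtype=payloads[0].dtype, device=device))
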
diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
index ebaef774b..fb6ef534b 100644
--- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py
@@ -3,11 +3,11 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.legacy.zero.shard_utils.commons import get_shard
 from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils import get_current_device
 
 
 class TensorShardStrategy(BaseShardStrategy):
@@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
         if t.is_sharded:
             return
         if t.payload.device.type == "cuda":
-            assert t.payload.device == get_current_device(), (
+            assert t.payload.device == get_accelerator().get_current_device(), (
                 f"shard tensor on cuda device index {t.payload.device.index},"
-                f" but current cuda device is {get_current_device()}"
+                f" but current cuda device is {get_accelerator().get_current_device()}"
             )
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.payload_reset(sharded_payload)
@@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
         world_size = dist.get_world_size(process_group)
         rank = dist.get_rank(process_group)
 
-        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
+        buffer = torch.empty(
+            payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
+        )
         buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
         buffer_list[rank].copy_(t.payload)
 
diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py
index 85f2ac215..bb7744a80 100644
--- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.logging import get_dist_logger
-from colossalai.utils import disposable, get_current_device
+from colossalai.utils import disposable
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
 
 from ._utils import (
@@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
             self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
             if gpc.get_global_rank() == 0:
                 with open(filename, "w+") as f:
-                    f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
-                    f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
+                    f.write(
+                        f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
+                    )
+                    f.write(
+                        f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
+                    )
                     f.write("CUDA model data (GB)\n")
                     f.write("\n")
                     f.write("CUDA non model data (GB)\n")
@@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = (
-                colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
+                colo_device_memory_capacity(get_accelerator().get_current_device())
+                - self._memstats_collector._memstats.max_overall_cuda
             )
 
     @torch.no_grad()
diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py
index 892e9f31d..332f44d53 100644
--- a/colossalai/legacy/zero/sharded_model/zero_hook.py
+++ b/colossalai/legacy/zero/sharded_model/zero_hook.py
@@ -3,13 +3,13 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.registry import OPHOOKS
 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
 from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
 from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
 
 
@@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
         self.process_group = process_group
 
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
-        self.computing_device = get_current_device()
+        self.computing_device = get_accelerator().get_current_device()
 
         self._memstarts_collector = memstarts_collector
         self._stateful_tensor_mgr = stateful_tensor_mgr
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index c5bb50862..f5815d05d 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -8,9 +8,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 
+from colossalai.accelerator import get_accelerator
 from colossalai.moe._operation import moe_cumsum
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.utils import get_current_device
 
 
 class MoeRouter(nn.Module, ABC):
@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """
 
-    def __init__(self,
-                 k_value: int,
-                 capacity_factor_train: float,
-                 capacity_factor_eval: float,
-                 min_capacity: int,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True,
-                 use_kernel: bool = False):
+    def __init__(
+        self,
+        k_value: int,
+        capacity_factor_train: float,
+        capacity_factor_eval: float,
+        min_capacity: int,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+        use_kernel: bool = False,
+    ):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train
@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
         if router_probs.dim() == expert_indices.dim() == 2:
             router_probs = router_probs.unsqueeze(0)
             expert_indices = expert_indices.unsqueeze(0)
-        assert router_probs.dim() == expert_indices.dim() == 3, \
-            "router_probs must be 3D tensor and expert_indices must be 4D tensor"
+        assert (
+            router_probs.dim() == expert_indices.dim() == 3
+        ), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
 
         # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
         expert_mask = F.one_hot(expert_indices, num_experts)
@@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """
 
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=1,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        select_policy: str = "first",
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=1,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
             self.uniform = torch.distributions.uniform.Uniform(
-                low=torch.tensor(0.0, device=get_current_device()),
-                high=torch.tensor(1.0, device=get_current_device())
+                low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
+                high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
             ).rsample
 
     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
@@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation.
     """
 
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=2,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=2,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
 
     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         """
@@ -255,8 +266,8 @@ class Top2Router(MoeRouter):
         top2_idx = torch.argmax(logits_except1, dim=-1)
         mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
 
-        cmask = (mask1 + mask2)    # loss: [s, e]
-        cmask = cmask.float() / 2.0    # div 2 to normalize it to 1
+        cmask = mask1 + mask2  # loss: [s, e]
+        cmask = cmask.float() / 2.0  # div 2 to normalize it to 1
 
         # calculate loss
         expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
@@ -269,7 +280,7 @@ class Top2Router(MoeRouter):
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
 
-        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)    # rank1: [s, e]
+        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
         rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
         rank2 += torch.sum(mask1, dim=-2, keepdim=True)
 
@@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
             oversubscribed / reach capacity.
     """
 
-    def __init__(self,
-                 num_selected_experts: int,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
-                         drop_tks)
+    def __init__(
+        self,
+        num_selected_experts: int,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
+        )
 
     def forward(
         self,
@@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
         # The combine array will be used for combining expert outputs, scaled by the
         # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
         # expert_capacity].
-        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
+        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
 
         return combine_array, dispatch_mask
 
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
index 5a17a6e0d..e25e7dd48 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -7,13 +7,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
+from colossalai.accelerator import get_accelerator
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
-from colossalai.utils import get_current_device
 
 
 class ForceFP32Parameter(torch.nn.Parameter):
-
     def half(self, memory_format=None):
         return self.data.clone()
 
@@ -30,8 +29,8 @@ class NormalNoiseGenerator:
 
     def __init__(self, num_experts: int):
         self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
+            loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
+            scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
         ).rsample
 
     def __call__(self, inputs: torch.Tensor):
@@ -52,8 +51,8 @@ class UniformNoiseGenerator:
 
     def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(1.0 - eps, device=get_current_device()),
-            high=torch.tensor(1.0 + eps, device=get_current_device()),
+            low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
+            high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
         ).rsample
 
     def __call__(self, inputs: torch.Tensor):
@@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
     epsize_param_dict = dict()
     for param in model.parameters():
         if not is_moe_tensor(param):
-            ep_size = 1    # set ep_size to 1 for dp parameters
+            ep_size = 1  # set ep_size to 1 for dp parameters
         else:
             ep_size = get_ep_size(param)
         if ep_size not in epsize_param_dict:
@@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
         nproc_per_node = int(nproc_per_node)
     else:
-        assert dist.get_world_size() % nproc_per_node == 0, \
-            "nproc_per_node should be a divisor of world_size."
+        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node
 
     intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
-        ep_intra_ranks = [
-            i * nproc_per_node + j
-            for j in range(nproc_per_node)
-            if j in ep_group_ranks
-        ]
+        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
@@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
             intra_src_rank = ep_intra_ranks[0]
 
     ep_inter_node_group = None
-    ep_inter_ranks = [
-        ep_group_ranks[0] + i * nproc_per_node
-        for i in range(num_node)
-    ]
+    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
     if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks:
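
The MoE noise generators construct their torch.distributions objects directly on the accelerator's device so that sampled noise lands next to the router logits. A condensed sketch of the uniform variant; eps matches the default in the hunk above and the logits are illustrative:

    import torch

    from colossalai.accelerator import get_accelerator

    eps = 1e-2
    device = get_accelerator().get_current_device()
    uniform = torch.distributions.uniform.Uniform(
        low=torch.tensor(1.0 - eps, device=device),
        high=torch.tensor(1.0 + eps, device=device),
    ).rsample

    logits = torch.randn(4, 8, device=device)
    noisy_logits = logits * uniform(logits.shape)
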
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 72480526b..20f316c2a 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -7,10 +7,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
+from colossalai.accelerator import get_accelerator
 from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
 
 from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule
@@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
 
     def _prepare_inputs_for_interval_stage(self):
         """
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index cbf6dd80f..91d936bfd 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -6,10 +6,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
 
 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
@@ -56,7 +56,7 @@ class InterleavedSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
 
     def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
         """Helper method to get the model chunk ID given the iteration number.
@@ -292,7 +292,7 @@ class InterleavedSchedule(PipelineSchedule):
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
 
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
 
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index fd918cf19..606bf8797 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -6,10 +6,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
 
 from ._utils import (
     detach,
@@ -80,7 +80,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
 
     def recv_forward(self, prev_rank: int = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
@@ -297,7 +297,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
 
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
 
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 96fd3bd7b..0d2cc1b33 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -7,7 +7,7 @@ from torch import nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup, get_world_size
 
-from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state
+from colossalai.accelerator import get_accelerator
 
 
 class SeqParallelUtils:
@@ -110,10 +110,10 @@ class Randomizer:
         # 1. get the current rng state
         # 2. set the seed and store the rng state
         # 3. recover the original rng state
-        device_original_rng_state = get_rng_state()
-        manual_seed(seed)
-        self.device_rng_state = get_rng_state()
-        set_rng_state(device_original_rng_state)
+        device_original_rng_state = get_accelerator().get_rng_state()
+        get_accelerator().manual_seed(seed)
+        self.device_rng_state = get_accelerator().get_rng_state()
+        get_accelerator().set_rng_state(device_original_rng_state)
 
         # to the same for cpu rng state
         cpu_original_rng_state = torch.get_rng_state()
@@ -122,10 +122,10 @@ class Randomizer:
         torch.set_rng_state(cpu_original_rng_state)
 
     def _set_device_rng_state(self, rng_state):
-        set_rng_state(rng_state)
+        get_accelerator().set_rng_state(rng_state)
 
     def _get_device_rng_state(self):
-        current_state = get_rng_state()
+        current_state = get_accelerator().get_rng_state()
         return current_state
 
     def _set_cpu_rng_state(self, rng_state):
@@ -210,7 +210,7 @@ class Randomizer:
         index = Randomizer.index()
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
 
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@@ -232,7 +232,7 @@ class Randomizer:
 
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
 
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
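
The Randomizer now drives device rng handling through the accelerator: save the current state, seed, capture the seeded state, then restore the original. A minimal sketch; the seed value is arbitrary:

    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()

    # 1. remember the current rng state, 2. seed and capture that state, 3. restore the original.
    original_state = accelerator.get_rng_state()
    accelerator.manual_seed(1234)
    seeded_state = accelerator.get_rng_state()
    accelerator.set_rng_state(original_state)
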
diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py
index 7cd24b0ad..5f6864ff0 100644
--- a/colossalai/testing/utils.py
+++ b/colossalai/testing/utils.py
@@ -9,7 +9,8 @@ from typing import Any, Callable, List
 import torch
 import torch.multiprocessing as mp
 from packaging import version
-from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
+
+from colossalai.accelerator import get_accelerator
 
 
 def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
 
     def _wrap_func(f):
         def _execute_by_gpu_num(*args, **kwargs):
-            num_avail_gpu = device_count()
+            num_avail_gpu = get_accelerator().device_count()
             if num_avail_gpu >= min_gpus:
                 f(*args, **kwargs)
 
@@ -263,11 +264,11 @@ def clear_cache_before_run():
 
     def _wrap_func(f):
         def _clear_cache(*args, **kwargs):
-            empty_cache()
-            reset_peak_memory_stats()
-            reset_max_memory_allocated()
-            reset_max_memory_cached()
-            synchronize()
+            get_accelerator().empty_cache()
+            get_accelerator().reset_peak_memory_stats()
+            get_accelerator().reset_max_memory_allocated()
+            get_accelerator().reset_max_memory_cached()
+            get_accelerator().synchronize()
             gc.collect()
             f(*args, **kwargs)
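
The test decorator above routes cache clearing and memory-stat resets through the accelerator singleton. A hedged sketch of the same housekeeping outside the decorator, assuming a CUDA or NPU backend is active (the CPU backend may not support every memory-stat call):

    import gc

    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()
    print(f"visible accelerator devices: {accelerator.device_count()}")

    accelerator.empty_cache()
    accelerator.reset_peak_memory_stats()
    accelerator.reset_max_memory_allocated()
    accelerator.reset_max_memory_cached()
    accelerator.synchronize()
    gc.collect()
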
 
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 0246a35e2..9d33e4668 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -7,17 +7,12 @@ from .common import (
     is_ddp_ignored,
     set_seed,
 )
-from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
 from .multi_tensor_apply import multi_tensor_applier
 from .tensor_detector import TensorDetector
 from .timer import MultiTimer, Timer
 
 __all__ = [
     "conditional_context",
-    "get_current_device",
-    "synchronize",
-    "empty_cache",
-    "set_to_cuda",
     "Timer",
     "MultiTimer",
     "multi_tensor_applier",
@@ -28,6 +23,4 @@ __all__ = [
     "free_storage",
     "set_seed",
     "is_ddp_ignored",
-    "set_device",
-    "IS_NPU_AVAILABLE",
 ]
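
With these helpers removed from colossalai.utils (and colossalai/utils/device.py deleted below), call sites move onto the accelerator object. A hedged sketch of the correspondence for names used elsewhere in this patch, shown for a GPU or NPU backend:

    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()

    # Former colossalai.utils helpers -> accelerator methods, as used throughout this patch.
    device = accelerator.get_current_device()  # replaces get_current_device()
    accelerator.synchronize()                  # replaces synchronize()
    accelerator.empty_cache()                  # replaces empty_cache()
    num_devices = accelerator.device_count()   # replaces device_count()
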
diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py
deleted file mode 100644
index c70dbdaa5..000000000
--- a/colossalai/utils/device.py
+++ /dev/null
@@ -1,223 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Any, Dict, List, Optional, Tuple, Callable
-
-import torch
-import torch.distributed as dist
-
-IS_NPU_AVAILABLE: bool = False
-try:
-    import torch_npu  # noqa
-
-    IS_NPU_AVAILABLE = torch.npu.is_available()
-except ImportError:
-    pass
-
-
-def set_to_cuda(models):
-    """Send model to gpu.
-
-    :param models: nn.module or a list of module
-    """
-    if isinstance(models, list) and len(models) > 1:
-        ret = []
-        for model in models:
-            ret.append(model.to(get_current_device()))
-        return ret
-    elif isinstance(models, list):
-        return models[0].to(get_current_device())
-    else:
-        return models.to(get_current_device())
-
-
-def get_current_device() -> torch.device:
-    """
-    Returns currently selected device (gpu/cpu).
-    If cuda available, return gpu, otherwise return cpu.
-    """
-    if torch.cuda.is_available():
-        return torch.device(f"cuda:{torch.cuda.current_device()}")
-    elif IS_NPU_AVAILABLE:
-        return torch.device(f"npu:{torch.npu.current_device()}")
-    else:
-        return torch.device("cpu")
-
-
-def _dispatch_device_func(fn_name: str, *args, **kwargs):
-    if torch.cuda.is_available():
-        return getattr(torch.cuda, fn_name)(*args, **kwargs)
-    elif IS_NPU_AVAILABLE:
-        return getattr(torch.npu, fn_name)(*args, **kwargs)
-    else:
-        raise RuntimeError("No device available")
-
-
-# device semantics
-
-
-def can_device_access_peer(device, peer_device) -> bool:
-    return _dispatch_device_func("can_device_access_peer", device, peer_device)
-
-
-def current_device() -> int:
-    return _dispatch_device_func("current_device")
-
-
-def current_stream(device=None):
-    return _dispatch_device_func("current_stream", device)
-
-
-def default_stream(device=None):
-    return _dispatch_device_func("default_stream", device)
-
-
-def device_count() -> int:
-    return _dispatch_device_func("device_count")
-
-
-def get_device_capability(device=None) -> Tuple[int, int]:
-    return _dispatch_device_func("get_device_capability", device)
-
-
-def get_device_name(device=None) -> str:
-    return _dispatch_device_func("get_device_name", device)
-
-
-def get_device_properties(device):
-    return _dispatch_device_func("get_device_properties", device)
-
-
-def set_device(index: Optional[int] = None) -> None:
-    if index is None:
-        index = dist.get_rank() % device_count()
-    _dispatch_device_func("set_device", index)
-
-
-def set_stream(stream_):
-    return _dispatch_device_func("set_stream", stream_)
-
-
-def stream(stream_):
-    return _dispatch_device_func("stream", stream_)
-
-
-def synchronize():
-    return _dispatch_device_func("synchronize")
-
-
-def utilization(device=None) -> int:
-    return _dispatch_device_func("utilization", device)
-
-
-# random number generator
-
-
-def get_rng_state(device="cuda") -> torch.Tensor:
-    return _dispatch_device_func("get_rng_state", device)
-
-
-def get_rng_state_all() -> List[torch.Tensor]:
-    return _dispatch_device_func("get_rng_state_all")
-
-
-def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
-    return _dispatch_device_func("set_rng_state", new_state, device)
-
-
-def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
-    return _dispatch_device_func("set_rng_state_all", new_states)
-
-
-def manual_seed(seed: int) -> None:
-    return _dispatch_device_func("manual_seed", seed)
-
-
-def manual_seed_all(seed: int) -> None:
-    return _dispatch_device_func("manual_seed_all", seed)
-
-
-def seed() -> None:
-    return _dispatch_device_func("seed")
-
-
-def seed_all() -> None:
-    return _dispatch_device_func("seed_all")
-
-
-def initial_seed() -> int:
-    return _dispatch_device_func("initial_seed")
-
-
-# streams and events
-
-
-def Stream(device=None, priority=0, **kwargs):
-    return _dispatch_device_func("Stream", device, priority, **kwargs)
-
-
-def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
-    return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
-
-
-# memory management
-
-
-def empty_cache() -> None:
-    return _dispatch_device_func("empty_cache")
-
-
-def memory_stats(device=None) -> Dict[str, Any]:
-    return _dispatch_device_func("memory_stats", device)
-
-
-def memory_summary(device=None, abbreviated=False) -> str:
-    return _dispatch_device_func("memory_summary", device, abbreviated)
-
-
-def memory_snapshot():
-    return _dispatch_device_func("memory_snapshot")
-
-
-def memory_allocated(device=None) -> int:
-    return _dispatch_device_func("memory_allocated", device)
-
-
-def max_memory_allocated(device=None) -> int:
-    return _dispatch_device_func("max_memory_allocated", device)
-
-
-def reset_max_memory_allocated(device=None) -> None:
-    return _dispatch_device_func("reset_max_memory_allocated", device)
-
-
-def reset_max_memory_cached(device=None) -> None:
-    return _dispatch_device_func("reset_max_memory_cached", device)
-
-
-def memory_reserved(device=None) -> int:
-    return _dispatch_device_func("memory_reserved", device)
-
-
-def max_memory_reserved(device=None) -> int:
-    return _dispatch_device_func("max_memory_reserved", device)
-
-
-def set_per_process_memory_fraction(fraction: float, device=None) -> None:
-    return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
-
-
-def reset_peak_memory_stats(device=None) -> None:
-    return _dispatch_device_func("reset_peak_memory_stats", device)
-
-
-# amp
-
-
-def autocast() -> Callable:
-    if torch.cuda.is_available():
-        return torch.cuda.amp.autocast()
-    elif IS_NPU_AVAILABLE:
-        return torch.npu.amp.autocast()
-    else:
-        raise RuntimeError("No device available")
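The module deleted above routed every call through `getattr` on `torch.cuda` or `torch.npu`; the condensed sketch below only restates that dispatch idea, which the per-backend accelerator classes now own:

```python
import torch


def _dispatch(fn_name: str, *args, **kwargs):
    # pick the backend namespace at call time, mirroring the removed helper
    if torch.cuda.is_available():
        backend = torch.cuda
    elif getattr(torch, "npu", None) is not None and torch.npu.is_available():
        backend = torch.npu
    else:
        raise RuntimeError("No device available")
    return getattr(backend, fn_name)(*args, **kwargs)


_dispatch("manual_seed_all", 42)   # e.g. torch.cuda.manual_seed_all(42) on a CUDA box
```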
diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py
index 8ab6b46f2..0fbdd0932 100644
--- a/colossalai/utils/timer.py
+++ b/colossalai/utils/timer.py
@@ -3,7 +3,7 @@
 import time
 from typing import Tuple
 
-from .device import synchronize
+from colossalai.accelerator import get_accelerator
 
 
 class Timer:
@@ -21,13 +21,13 @@ class Timer:
 
     @property
     def current_time(self) -> float:
-        synchronize()
+        get_accelerator().synchronize()
         return time.time()
 
     def start(self):
         """Firstly synchronize cuda, reset the clock and then start the timer."""
         self._elapsed = 0
-        synchronize()
+        get_accelerator().synchronize()
         self._start_time = time.time()
         self._started = True
 
@@ -44,7 +44,7 @@ class Timer:
         Returns:
             int: Start-stop interval.
         """
-        synchronize()
+        get_accelerator().synchronize()
         end_time = time.time()
         elapsed = end_time - self._start_time
         if keep_in_history:
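Because `current_time`, `start`, and `stop` all synchronize the accelerator before reading the wall clock, device-side work is fully accounted for; a short usage sketch, assuming `Timer.stop` keeps the `keep_in_history` keyword shown in the hunk:

```python
from colossalai.utils import Timer

timer = Timer()
timer.start()                               # synchronize, then record the start time
# ... launch device kernels / run a training step ...
elapsed = timer.stop(keep_in_history=True)  # synchronize again, return the interval
print(f"step took {elapsed:.3f}s")
```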
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index defc6c4cb..7a9f58701 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -6,8 +6,7 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
+from colossalai.accelerator import get_accelerator
 
 
 class TensorState(Enum):
@@ -107,7 +106,7 @@ class Chunk:
         self.valid_end = self.shard_size
 
         self.dtype = dtype
-        device = init_device or get_current_device()
+        device = init_device or get_accelerator().get_current_device()
 
         # chunk_temp is a global chunk, which only exists during building the chunks.
         self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
@@ -125,7 +124,7 @@ class Chunk:
         # configure the init device of the shard
         # no-offload default: fp16, fp32 -> CUDA
         # offload default: fp16, fp32 -> CPU
-        self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
+        self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()
 
         self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
         self.shard_mem = self.chunk_mem // self.pg_size
@@ -192,10 +191,7 @@ class Chunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.is_gathered or self.cuda_shard is not None:
-                return "npu" if IS_NPU_AVAILABLE else "cuda"
-            else:
-                return "cpu"
+            return get_accelerator().name
 
     @property
     def payload(self) -> torch.Tensor:
@@ -297,7 +293,7 @@ class Chunk:
             self.valid_end = self.utilized_size - self.shard_begin
 
         if self.chunk_temp.device.type == "cpu":
-            self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
+            self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
             self.__update_tensors_ptr()
         else:
             self.cuda_global_chunk = self.chunk_temp
@@ -334,12 +330,12 @@ class Chunk:
             return
 
         if device.type == "cuda" or device.type == "npu":
-            assert device == get_current_device(), "can't move chunk to another device"
+            assert device == get_accelerator().get_current_device(), "can't move chunk to another device"
 
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
 
             if not self.pin_memory:
                 self.cpu_shard = None
@@ -394,7 +390,9 @@ class Chunk:
             if self.extra_dp_group is not None:
                 dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
         else:
-            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
+            self.cuda_shard = torch.empty(
+                self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
+            )
 
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
             dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -533,7 +531,7 @@ class Chunk:
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_current_device())
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
         # avoid to transform FP32 in CPU
         self.cuda_shard = temp.to(self.dtype)
 
@@ -631,7 +629,7 @@ class Chunk:
             grad_chunk.valid_end = self.valid_end
 
             if grad_chunk.chunk_temp.device.type == "cpu":
-                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
             else:
                 grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
             grad_chunk.chunk_temp = None
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 5f4f37c26..5bc662a61 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -5,7 +5,8 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
-from colossalai.utils import free_storage, get_current_device
+from colossalai.accelerator import get_accelerator
+from colossalai.utils import free_storage
 
 from .chunk import Chunk, ChunkFullError, TensorState
 
@@ -20,7 +21,7 @@ class ChunkManager:
     """
 
     def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
-        self.device = init_device or get_current_device()
+        self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
@@ -107,7 +108,7 @@ class ChunkManager:
             return
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device())
         self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
 
@@ -276,7 +277,10 @@ class ChunkManager:
                 accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
             else:
                 accumulated_grad = (
-                    chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
+                    chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device())
+                    .clone()
+                    .detach()
+                    .mul_(chunk.pg_size)
                 )
             accumulated_grad_gathered = False
 
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 5217b8036..79831cf33 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
 
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
@@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import (
     is_distributed_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):
 
             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
+                p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision)
                 continue
 
             # create a fp16 parameter
@@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper):
         for buffer in self.module.buffers():
             if isinstance(buffer, LazyTensor):
                 buffer.materialize()
-            buffer.data = buffer.to(get_current_device())
+            buffer.data = buffer.to(get_accelerator().get_current_device())
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)
 
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 8f828bd6c..09fad1e77 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
 from torch.nn import Parameter
 from torch.optim import Optimizer
 
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored
 
 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
 
             grad_chunk.l2_norm = None  # clear l2 norm
 
-        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
         for group, part_norm in group_to_norm.items():
             comm_buffer.fill_(part_norm)
             dist.all_reduce(comm_buffer, group=group)
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
                         continue
 
                     if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-                        self.chunk_manager.move_chunk(chunk32, get_current_device())
+                        self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
                         # stores grad now
-                        self.chunk_manager.move_chunk(chunk16, get_current_device())
-                        self.module.set_chunk_grad_device(chunk16, get_current_device())
+                        self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+                        self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
                         fp32_params_used_cuda_margin_mem += chunk32.payload_mem
 
             for group in self.param_groups:
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
                         state = self.optim.state[fake_param]
                         for k, v in state.items():
                             if isinstance(v, torch.Tensor):
-                                state[k] = v.to(get_current_device())
+                                state[k] = v.to(get_accelerator().get_current_device())
 
     def _register_states_(self):
         for group in self.optim.param_groups:
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self,
         param_id: int,
         state_names: list,
-        device: torch.device = get_current_device(),
+        device: torch.device = get_accelerator().get_current_device(),
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
index b5e40a817..e302805df 100644
--- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
+++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 from colossalai.zero.gemini.chunk import ChunkManager
 
 from .memory_stats import MemStats
@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
     def cuda_margin_mem(self) -> float:
         from colossalai.legacy.utils.memory import colo_device_memory_capacity
 
-        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
+        return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda
diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py
index 513a6326d..82c8e9dab 100644
--- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py
+++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py
@@ -5,7 +5,7 @@ from time import sleep, time
 
 import torch
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 
 class MemoryMonitor:
@@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         super().__init__()
         self.keep_measuring = False
 
-        current_device = get_current_device()
+        current_device = get_accelerator().get_current_device()
 
         def _set_cuda_device():
             torch.cuda.set_device(current_device)
@@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                colo_device_memory_used(get_current_device()),
+                colo_device_memory_used(get_accelerator().get_current_device()),
             )
             sleep(self.interval)
         return max_usage
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 8a74eb587..388999549 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
@@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy):
             # init offload optim settings
             # keep gathered chunks are in CUDA
             if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
-                device = get_current_device()
+                device = get_accelerator().get_current_device()
             else:
                 device = torch.device("cpu")
                 # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
@@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy):
             int: the volume of memory that is evicted
         """
         start = time()
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
@@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy):
             # init offload optim settings
             # keep gathered chunks are in CUDA
             if chunk.keep_gathered:
-                grads_device_map[p] = get_current_device()
+                grads_device_map[p] = get_accelerator().get_current_device()
             else:
                 grads_device_map[p] = torch.device("cpu")
 
diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py
index 5305953fe..b563ea5b2 100644
--- a/colossalai/zero/gemini/utils.py
+++ b/colossalai/zero/gemini/utils.py
@@ -6,7 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 from .chunk import Chunk
 
@@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
     if chunk.cuda_shard is not None:
         shard_temp = chunk.cuda_shard
     else:
-        shard_temp = chunk.cpu_shard.to(get_current_device())
+        shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device())
 
     shard_temp = shard_temp.to(dtype)
 
-    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
+    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device())
     gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
 
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index c1b35ee17..81eba6fe5 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import (
     BF16MixedPrecisionMixin,
     FP16MixedPrecisionMixin,
@@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
-# from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
-
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
 
@@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # initialize communication stream for
         # communication-computation overlapping
         if self._overlap_communication:
-            self._comm_stream = device_utils.Stream()
+            self._comm_stream = get_accelerator().Stream()
 
         # reduction hook is only used if overlapping communication
         # or stage 2 is used
@@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return len(self._working_param_groups)
 
     def _sanity_checks(self):
-        assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
+        assert get_accelerator().name in ["cuda", "npu"], "device is required"
         for param_group in self.optim.param_groups:
             group_params = param_group["params"]
             for param in group_params:
@@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
-        device = "cpu" if self._cpu_offload else get_current_device()
+        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
 
         for param in param_list:
             padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
@@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     if len(moe_grad_list) > 0:
                         moe_flat_grads.record_stream(stream)
                 # waiting for ops in the default stream finishing
-                stream.wait_stream(device_utils.current_stream())
+                stream.wait_stream(get_accelerator().current_stream())
             else:
-                stream = device_utils.current_stream()
+                stream = get_accelerator().current_stream()
 
-            with device_utils.stream(stream):
+            with get_accelerator().stream(stream):
                 group_id = self._bucket_store.current_group_id
 
                 if self.moe_extra_dp_pg is None:
@@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()
 
         self.zero_grad()
 
@@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()
 
         self.zero_grad()
 
@@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
@@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         norm_type = float(norm_type)
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
 
@@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             # Sum across all model parallel GPUs.
             total_norm_exponentiated_cuda = torch.tensor(
-                [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
             )
             torch.distributed.all_reduce(
                 total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
@@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             Dict: the pytorch form state_dict
         """
         zero_state = dict()
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
@@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         ret_block = dict()
         ret_block_size = 0
 
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
         for param_idx, states in local_states.items():
             current_block_size = 0
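The stream-related hunks above keep the existing overlap scheme and only swap the namespace; a minimal sketch of that pattern using the accelerator calls named in this patch (`Stream`, `current_stream`, `stream`, `synchronize`), assuming a process group is already initialized:

```python
import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
comm_stream = acc.Stream()                      # side stream for gradient reduction

grads = torch.randn(1024, device=acc.get_current_device())

comm_stream.wait_stream(acc.current_stream())   # wait for already-queued compute
with acc.stream(comm_stream):
    dist.all_reduce(grads)                      # overlaps with later default-stream work
    grads.record_stream(comm_stream)            # keep the buffer alive for the side stream

acc.synchronize()                               # drain both streams before reading grads
```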
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 7a0e3b1a0..e87eafb6e 100644
--- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -45,7 +45,6 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 ```
 ## Define Plugin
 Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously.
@@ -149,7 +148,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost(
 
 ## Training GPT-2 using hybrid parallelism
 
-In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. 
+In the previous tutorial, we've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training.
 Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training.
 ```python
 def train_epoch(
@@ -204,4 +203,4 @@ Training the gpt-2 model
 for epoch in range(NUM_EPOCHS):
     train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
 ```
-<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py  -->
\ No newline at end of file
+<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py  -->
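The tutorial text in this hunk defers to `booster.execute_pipeline` for stage scheduling; the sketch below is only an illustration of that call shape, assuming the `booster`, `model`, `optimizer`, and `_criterion` objects built earlier in the tutorial and that the method returns a dict whose `"loss"` entry is populated on the last pipeline stage:

```python
def train_step(batch_iter, model, optimizer, _criterion, booster, use_pipeline: bool):
    if use_pipeline:
        outputs = booster.execute_pipeline(
            batch_iter, model, _criterion, optimizer, return_loss=True
        )
        loss = outputs["loss"]           # meaningful on the last stage only
    else:
        batch = next(batch_iter)
        outputs = model(**batch)
        loss = _criterion(outputs, batch)
        booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()
    return loss
```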
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
index 117406980..ae941b489 100644
--- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -43,7 +43,6 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 ```
 ### 定义plugin
 定义一个[`HybridParallelPlugin`](../basics/booster_plugins.md)对象,指定所需要使用的并行策略,在该例子中,同时使用了流水线并行和zero1.
@@ -201,4 +200,4 @@ def train_epoch(
 for epoch in range(NUM_EPOCHS):
     train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
 ```
-<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py  -->
\ No newline at end of file
+<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py  -->
diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py
index 5396de693..40b11d649 100644
--- a/examples/community/roberta/pretraining/run_pretraining.py
+++ b/examples/community/roberta/pretraining/run_pretraining.py
@@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var
 from utils.logger import Logger
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.context import ParallelMode
 from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 
 
@@ -53,7 +53,7 @@ def main():
     set_global_variables(launch_time, args.tensorboard_path)
 
     world_size = torch.distributed.get_world_size()
-    get_current_device()
+    get_accelerator().get_current_device()
 
     # build model, optimizer and criterion
     if args.distplan.startswith("CAI"):
@@ -67,7 +67,10 @@ def main():
 
         # build GPT model
         with ColoInitContext(
-            device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg
+            device=get_accelerator().get_current_device(),
+            dtype=torch.half,
+            default_dist_spec=default_dist_spec,
+            default_pg=shard_pg,
         ):
             config, model, numel = get_model(args, logger)
 
@@ -78,7 +81,7 @@ def main():
         elif args.distplan == "CAI_Gemini":
             gemini_config = dict(
                 strict_ddp_mode=args.tp_degree == 1,
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 placement_policy=args.placement,
                 pin_memory=True,
                 hidden_dim=model.config.hidden_size,
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index 1a7f8da7f..cc2b2ebc7 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -20,11 +20,11 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -386,7 +386,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+            torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -401,7 +401,7 @@ def main(args):
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
             sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
 
-            pipeline.to(get_current_device())
+            pipeline.to(get_accelerator().get_current_device())
 
             for example in tqdm(
                 sample_dataloader,
@@ -578,8 +578,8 @@ def main(args):
     # Move text_encoder and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
-    vae.to(get_current_device(), dtype=weight_dtype)
-    text_encoder.to(get_current_device(), dtype=weight_dtype)
+    vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
+    text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@@ -613,7 +613,7 @@ def main(args):
             torch.cuda.reset_peak_memory_stats()
             # Move batch to gpu
             for key, value in batch.items():
-                batch[key] = value.to(get_current_device(), non_blocking=True)
+                batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)
 
             # Convert images to latent space
             optimizer.zero_grad()
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index ea6dde8bb..227488abe 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -21,13 +21,13 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -385,7 +385,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+            torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -400,7 +400,7 @@ def main(args):
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
             sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
 
-            pipeline.to(get_current_device())
+            pipeline.to(get_accelerator().get_current_device())
 
             for example in tqdm(
                 sample_dataloader,
@@ -598,8 +598,8 @@ def main(args):
     # Move text_encoder and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
-    vae.to(get_current_device(), dtype=weight_dtype)
-    text_encoder.to(get_current_device(), dtype=weight_dtype)
+    vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
+    text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@@ -633,7 +633,7 @@ def main(args):
             torch.cuda.reset_peak_memory_stats()
             # Move batch to gpu
             for key, value in batch.items():
-                batch[key] = value.to(get_current_device(), non_blocking=True)
+                batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)
 
             # Convert images to latent space
             optimizer.zero_grad()
diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py
index 13df516d4..5871bbf87 100644
--- a/examples/images/resnet/train.py
+++ b/examples/images/resnet/train.py
@@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
@@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
 @torch.no_grad()
 def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
     model.eval()
-    correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-    total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+    correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+    total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
     for images, labels in test_dataloader:
         images = images.cuda()
         labels = labels.cuda()
diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py
index b770bc9cf..078017324 100644
--- a/examples/images/vit/vit_benchmark.py
+++ b/examples/images/vit/vit_benchmark.py
@@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224
 
 
 def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    from colossalai.accelerator import get_accelerator
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction
 
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
     if size_in_GB * (1024**3) < cuda_capacity:
         colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
         print(f"Limiting GPU memory usage to {size_in_GB} GB")
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index 9a26098b3..26cac977a 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -6,10 +6,9 @@ import torch.distributed as dist
 import transformers
 
 import colossalai
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.inference import InferenceEngine
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-from colossalai.utils.device import get_current_device
 
 GIGABYTE = 1024**3
 MEGABYTE = 1024 * 1024
@@ -52,7 +51,7 @@ CONFIG_MAP = {
 
 
 def data_gen(batch_size: int = 4, seq_len: int = 512):
-    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
+    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device())
     attention_mask = torch.ones_like(input_ids)
     data = dict(input_ids=input_ids, attention_mask=attention_mask)
     return data
@@ -97,9 +96,9 @@ def print_details_info(outputs, model_config, args, whole_end2end):
         msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
 
     if torch.cuda.is_available():
-        msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
-        msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
-        msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"
+        msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n"
+        msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n"
+        msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n"
 
     print(msg)
 
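Since the benchmark now reads its counters from the accelerator, the same report can be produced with a small helper; a sketch reusing only the calls and the `GIGABYTE` constant that appear in this script:

```python
from colossalai.accelerator import get_accelerator

GIGABYTE = 1024**3


def memory_summary() -> str:
    acc = get_accelerator()
    return (
        f"device {acc.current_device()}: "
        f"max allocated {acc.max_memory_allocated() / GIGABYTE:.2f} GB, "
        f"max reserved {acc.max_memory_reserved() / GIGABYTE:.2f} GB"
    )
```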
diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py
index 8f85a9363..b5228c64e 100644
--- a/examples/inference/run_llama_inference.py
+++ b/examples/inference/run_llama_inference.py
@@ -5,9 +5,9 @@ import torch.distributed as dist
 from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.inference import InferenceEngine
 from colossalai.testing import spawn
-from colossalai.utils.device import get_current_device
 
 INPUT_TEXTS = [
     "What is the longest river in the world?",
@@ -57,7 +57,7 @@ def run_inference(args):
     )
 
     inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
-    inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
+    inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()}
     outputs = engine.generate(inputs)
 
     if rank == 0:
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index 563cfa58d..dc6768e58 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -18,11 +18,11 @@ from transformers import (
 )
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
@@ -59,7 +59,7 @@ def evaluate_model(
         use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
         is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
 
-        accum_loss = torch.zeros(1, device=get_current_device())
+        accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
             labels = batch["labels"]
@@ -88,8 +88,10 @@ def evaluate_model(
                     object_list = [None, None]
                     dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
 
-                    metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
-                    accum_loss.add_(object_list[1].to(get_current_device()))
+                    metric.add_batch(
+                        predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
+                    )
+                    accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))
 
             else:
                 batch = move_to_cuda(batch)
diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
index e811e1acb..b35112498 100644
--- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
+++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
@@ -7,13 +7,13 @@ from model_zoo import GPTLMLoss, get_gpt2_components
 from torch.utils._pytree import tree_map
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
 from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import spawn
-from colossalai.utils import get_current_device
 
 
 def parse_args():
@@ -41,7 +41,7 @@ def train_gpt(args):
             64,
             8,
         ),
-        device=get_current_device(),
+        device=get_accelerator().get_current_device(),
     )
     criterion = GPTLMLoss()
 
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 88b76c654..78d090ba2 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -12,12 +12,12 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_st
 from packaging import version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 CAI_VERSION = colossalai.__version__
 
@@ -141,7 +141,11 @@ def main():
     criterion = GPTLMLoss()
     torch.manual_seed(123)
     if args.distplan.startswith("CAI"):
-        ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
+        ctx = (
+            LazyInitContext(default_device=get_accelerator().get_current_device())
+            if args.distplan == "CAI_Gemini"
+            else nullcontext()
+        )
         # build GPT model
         with ctx:
             model = model_builder(args.model_type)(checkpoint=True)
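The Gemini examples in this patch only enter `LazyInitContext` when the Gemini plan is selected and fall back to `nullcontext` otherwise; a compact sketch of that pattern, with `build_model` standing in for the example's model builder:

```python
from contextlib import nullcontext

from colossalai.accelerator import get_accelerator
from colossalai.lazy import LazyInitContext


def build_with_optional_lazy_init(build_model, use_gemini: bool):
    # lazy init materializes parameters directly on the accelerator device
    ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if use_gemini
        else nullcontext()
    )
    with ctx:
        return build_model()
```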
diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py
index 62804eff8..eb56ee530 100644
--- a/examples/language/gpt/hybridparallelism/finetune.py
+++ b/examples/language/gpt/hybridparallelism/finetune.py
@@ -13,11 +13,11 @@ from tqdm import tqdm
 from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
@@ -54,7 +54,7 @@ def evaluate_model(
         use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
         is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
 
-        accum_loss = torch.zeros(1, device=get_current_device())
+        accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
             labels = batch["labels"]
@@ -83,8 +83,10 @@ def evaluate_model(
                     object_list = [None, None]
                     dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)
 
-                    metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
-                    accum_loss.add_(object_list[1].to(get_current_device()))
+                    metric.add_batch(
+                        predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
+                    )
+                    accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))
 
             else:
                 batch = move_to_cuda(batch)
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
index b2e3f71a5..ec3df50c4 100644
--- a/examples/language/gpt/titans/model/embed.py
+++ b/examples/language/gpt/titans/model/embed.py
@@ -5,6 +5,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 from torch.nn.parameter import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.base_layer import ParallelLayer
@@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_b
 from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
 from colossalai.legacy.nn.layer.utils import divide
 from colossalai.legacy.registry import LAYERS, LOSSES
-from colossalai.utils import get_current_device
 
 
 class VocabParallelEmbedding(torch.nn.Module):
@@ -96,7 +96,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])
         if position_ids is None:
-            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = torch.arange(
+                0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         position_embeddings = self.position_embeddings(position_ids)
 
@@ -194,7 +196,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
 
         # Allocate weights and initialize.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)
 
@@ -439,7 +441,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])
         if position_ids is None:
-            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = torch.arange(
+                0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         position_embeddings = self.position_embeddings(position_ids)
 
@@ -532,7 +536,7 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
         self._weight = None
 
         # Allocate weights and initialize.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)
 
diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py
index d7a79a022..2f8a76044 100644
--- a/examples/language/llama2/benchmark.py
+++ b/examples/language/llama2/benchmark.py
@@ -13,13 +13,12 @@ from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
 import colossalai
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Constants
@@ -166,7 +165,7 @@ def main():
     # Initialize Model and Optimizer
     # ==============================
     init_ctx = (
-        LazyInitContext(default_device=get_current_device())
+        LazyInitContext(default_device=get_accelerator().get_current_device())
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
         else nullcontext()
     )
@@ -197,7 +196,9 @@ def main():
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
     torch.set_default_dtype(torch.float)
-    coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(
+        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
+    )
     coordinator.print_on_master(
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
     )
@@ -223,7 +224,7 @@ def main():
             performance_evaluator.on_step_end(**batch)
 
     performance_evaluator.on_fit_end()
-    coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
 
 
 if __name__ == "__main__":
diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py
index a438833e1..6b9e8ef28 100644
--- a/examples/language/llama2/data_utils.py
+++ b/examples/language/llama2/data_utils.py
@@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.utils.data import DataLoader, Dataset, DistributedSampler
 
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 
 
 class StatefulDistributedSampler(DistributedSampler):
@@ -108,7 +108,9 @@ class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
         self.num_samples = num_samples
         self.max_length = max_length
-        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+        self.input_ids = torch.randint(
+            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+        )
         self.attention_mask = torch.ones_like(self.input_ids)
 
     def __len__(self):
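Note: the dataset change above amounts to allocating the synthetic batch on whichever device the accelerator reports as current. A condensed, hypothetical standalone version:

```python
import torch

from colossalai.accelerator import get_accelerator

# Hypothetical standalone version of the RandomDataset allocation above: the
# synthetic token ids live on the active accelerator from the start, so no
# per-batch host-to-device copy is needed.
device = get_accelerator().get_current_device()
input_ids = torch.randint(0, 32000, (1000, 2048), device=device)
attention_mask = torch.ones_like(input_ids)
```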
diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py
index f7708b1a3..66b540076 100644
--- a/examples/language/llama2/finetune.py
+++ b/examples/language/llama2/finetune.py
@@ -21,13 +21,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 
 def get_model_numel(model: nn.Module) -> int:
@@ -191,7 +191,9 @@ def main():
     config = LlamaConfig.from_pretrained(args.model_path)
     # use lazy init when using GeminiPlugin
     init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, GeminiPlugin)
+        else nullcontext()
     )
 
     with init_ctx:
diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py
index 6b1c92711..c2169a730 100644
--- a/examples/language/llama2/performance_evaluator.py
+++ b/examples/language/llama2/performance_evaluator.py
@@ -5,9 +5,8 @@ import torch
 import torch.distributed as dist
 from torch import Tensor
 
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
-from colossalai.utils.device import get_current_device
 
 
 def divide(x: float, y: float) -> float:
@@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float:
 def all_reduce_mean(x: float, world_size: int) -> float:
     if world_size == 1:
         return x
-    tensor = torch.tensor([x], device=get_current_device())
+    tensor = torch.tensor([x], device=get_accelerator().get_current_device())
     dist.all_reduce(tensor)
     tensor = tensor / world_size
     return tensor.item()
@@ -86,13 +85,13 @@ class PerformanceEvaluator:
         self.disable = self.ignore_steps > 0 and step < self.ignore_steps
         if self.disable:
             return
-        device_utils.synchronize()
+        get_accelerator().synchronize()
         self.timer.start()
 
     def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         if self.disable:
             return
-        device_utils.synchronize()
+        get_accelerator().synchronize()
         self.timer.end()
 
         batch_size, seq_len = input_ids.shape
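Note: the evaluator leans on the accelerator's synchronize() so asynchronous kernels are fully accounted for before the timer is read. A rough sketch of that timing pattern, with the training step elided:

```python
import time

from colossalai.accelerator import get_accelerator

# Synchronize before and after the step so queued device work is included in
# the measurement; the step itself is elided here.
get_accelerator().synchronize()
start = time.time()
# ... run one training step ...
get_accelerator().synchronize()
step_time = time.time() - start
```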
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index bb10f7a00..d32cec2a2 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -20,13 +20,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers.models.llama.tokenization_llama import LlamaTokenizer
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 MODEL_CONFIGS = {
     "7b": LlamaConfig(max_position_embeddings=4096),
@@ -227,7 +227,9 @@ def main():
     config = MODEL_CONFIGS[args.config]
     # use lazy init when using GeminiPlugin
     init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, GeminiPlugin)
+        else nullcontext()
     )
 
     with init_ctx:
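Note: the same init-context selection recurs across the example scripts above. A small helper capturing it (the name `make_init_ctx` is ours, and `plugin` is assumed to be an already-constructed booster plugin):

```python
from contextlib import nullcontext

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext


def make_init_ctx(plugin):
    # Lazy init targets the accelerator's current device only when Gemini is
    # in use; other plugins build the model eagerly under a null context.
    if isinstance(plugin, GeminiPlugin):
        return LazyInitContext(default_device=get_accelerator().get_current_device())
    return nullcontext()
```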
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 65562b386..03b660ecf 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -14,6 +14,7 @@ from transformers.models.llama import LlamaConfig
 from utils import PerformanceEvaluator, get_model_numel
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
@@ -21,7 +22,6 @@ from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 
 def move_to_cuda(batch, device):
@@ -64,13 +64,15 @@ class RandomDataset(Dataset):
                 )
                 self.input_ids.append(encode["input_ids"])
                 self.attention_mask.append(encode["attention_mask"])
-            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
-            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
+            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device())
+            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device())
             repeat_times = num_samples // self.input_ids.shape[0] + 1
             self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
             self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
         else:
-            self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+            self.input_ids = torch.randint(
+                0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+            )
             self.attention_mask = torch.ones_like(self.input_ids)
 
     def __len__(self):
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index b08436166..1ae661f54 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -15,6 +15,7 @@ from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
@@ -22,7 +23,6 @@ from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 
 def move_to_cuda(batch, device):
@@ -61,7 +61,9 @@ class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
         self.num_samples = num_samples
         self.max_length = max_length
-        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+        self.input_ids = torch.randint(
+            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+        )
         self.attention_mask = torch.ones_like(self.input_ids)
 
     def __len__(self):
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 7af02e24e..4fac7b507 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -14,12 +14,12 @@ from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.utils.data import DataLoader, Dataset
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn import HybridAdam
-from colossalai.utils import get_current_device
 
 # constants
 
@@ -159,7 +159,11 @@ if args.distplan == "colossalai":
     logger.info(f"plugin: {plugin}")
     booster = Booster(plugin=plugin, **booster_kwargs)
 
-    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext()
+    ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if args.plugin == "gemini"
+        else nullcontext()
+    )
 
     with ctx:
         model = PaLM(num_tokens=50304, dim=4096, depth=64)
diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py
index 4407a51c3..a4733126f 100644
--- a/examples/tutorial/new_api/cifar_resnet/train.py
+++ b/examples/tutorial/new_api/cifar_resnet/train.py
@@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
@@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
 @torch.no_grad()
 def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
     model.eval()
-    correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-    total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+    correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+    total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
     for images, labels in test_dataloader:
         images = images.cuda()
         labels = labels.cuda()
diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py
index 700e4d2e0..ec6c852b5 100644
--- a/examples/tutorial/new_api/cifar_vit/train.py
+++ b/examples/tutorial/new_api/cifar_vit/train.py
@@ -13,13 +13,13 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
@@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
 @torch.no_grad()
 def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
     model.eval()
-    correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-    total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+    correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+    total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
     for images, labels in test_dataloader:
         images = images.cuda()
         labels = labels.cuda()
diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py
index 990822c9f..e97c9017f 100644
--- a/examples/tutorial/new_api/glue_bert/finetune.py
+++ b/examples/tutorial/new_api/glue_bert/finetune.py
@@ -12,11 +12,11 @@ from tqdm import tqdm
 from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
@@ -45,7 +45,7 @@ def evaluate(
     model.eval()
 
     def evaluate_subset(dataloader: DataLoader):
-        accum_loss = torch.zeros(1, device=get_current_device())
+        accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
             outputs = model(**batch)
diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py
index 9bd23ffc8..3f0d04879 100755
--- a/examples/tutorial/opt/opt/run_clm.py
+++ b/examples/tutorial/opt/opt/run_clm.py
@@ -51,13 +51,13 @@ from transformers import (
 from transformers.utils.versions import require_version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.tensor import ProcessGroup
 from colossalai.legacy.utils import get_dataloader
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.zero import GeminiOptimizer
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -249,9 +249,9 @@ def parse_args():
 
 
 def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction
 
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
     if size_in_GB * (1024**3) < cuda_capacity:
         colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
         print("Using {} GB of GPU memory".format(size_in_GB))
@@ -265,7 +265,9 @@ class DummyDataloader:
         self.vocab_size = vocab_size
 
     def generate(self):
-        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
+        input_ids = torch.randint(
+            0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device()
+        )
         attention_mask = torch.ones_like(input_ids)
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
 
@@ -390,7 +392,7 @@ def main():
     if args.init_in_cpu:
         init_dev = torch.device("cpu")
     else:
-        init_dev = get_current_device()
+        init_dev = get_accelerator().get_current_device()
 
     cai_version = colossalai.__version__
     logger.info(f"using Colossal-AI version {cai_version}")
@@ -439,7 +441,9 @@ def main():
         except ImportError:
             # this works for unreleased main branch, and this may be released on 0.2.9
             from colossalai.zero import GeminiDDP
-        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+        model = GeminiDDP(
+            model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True
+        )
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
 
diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py
index 2c8b260e6..373ba28b8 100644
--- a/tests/test_auto_parallel/test_offload/test_perf.py
+++ b/tests/test_auto_parallel/test_offload/test_perf.py
@@ -5,13 +5,13 @@ import torch
 from torch.utils._pytree import tree_map
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
 from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
 from tests.test_auto_parallel.test_offload.model_utils import *
 from tests.test_tensor.common_utils import set_seed
@@ -31,7 +31,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
             64,
             8,
         ),
-        device=get_current_device(),
+        device=get_accelerator().get_current_device(),
     )
     criterion = LMLoss()
 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
index aba746f19..d57717326 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -10,12 +10,12 @@ try:
 except:
     NO_CODEGEN = True
 
+from colossalai.accelerator import get_accelerator
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
-from colossalai.utils import get_current_device
 from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
 
 
@@ -72,7 +72,11 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
         print("=" * msg_length)
 
     gemini_config = dict(
-        strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128
+        strict_ddp_mode=False,
+        device=get_accelerator().get_current_device(),
+        placement_policy="cpu",
+        pin_memory=True,
+        search_range_m=128,
     )
 
     gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index 3eaaf882c..490c015a8 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -5,7 +5,7 @@ import torch.distributed as dist
 from torch.optim import Adam
 
 import colossalai
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 
@@ -22,7 +22,7 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
 
 
 def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
-    device = device_utils.get_current_device()
+    device = get_accelerator().get_current_device()
     try:
         plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
@@ -69,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
             continue
         err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
 
-        device_utils.empty_cache()
+        get_accelerator().empty_cache()
 
         if err is None:
             passed_models.append(name)
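Note: between model sub-tests the plugin test now releases cached blocks through the accelerator rather than torch.cuda. The pattern in isolation:

```python
from colossalai.accelerator import get_accelerator

# Release cached device memory between independent test cases so one model's
# allocations do not skew the next, on whichever backend is active.
get_accelerator().empty_cache()
```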
diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py
index 7d2c81972..079022e93 100644
--- a/tests/test_legacy/test_comm/test_comm.py
+++ b/tests/test_legacy/test_comm/test_comm.py
@@ -2,12 +2,12 @@ import pytest
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.initialize import launch
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
 
@@ -16,7 +16,7 @@ SIZE = 8
 
 def check_all_gather():
     tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
-    tensor = tensor.to(get_current_device())
+    tensor = tensor.to(get_accelerator().get_current_device())
     print("Before:   Rank {0} - {1}".format(dist.get_rank(), tensor))
     tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
     print("After:    Rank {0} - {1}".format(dist.get_rank(), tensor))
@@ -27,7 +27,7 @@ def check_all_gather():
 
 def check_reduce_scatter():
     tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
-    tensor = tensor.to(get_current_device())
+    tensor = tensor.to(get_accelerator().get_current_device())
     print("Before:   Rank {0} - {1}".format(dist.get_rank(), tensor))
     tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
     print("After:    Rank {0} - {1}".format(dist.get_rank(), tensor))
@@ -38,7 +38,7 @@ def check_reduce_scatter():
 
 def check_all_reduce():
     tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
-    tensor = tensor.to(get_current_device())
+    tensor = tensor.to(get_accelerator().get_current_device())
     print("Before:   Rank {0} - {1}".format(dist.get_rank(), tensor))
     tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
     print("After:    Rank {0} - {1}".format(dist.get_rank(), tensor))
diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
index 8a9a73d65..f09df9253 100644
--- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
+++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.global_variables import tensor_parallel_env as env
@@ -16,13 +17,12 @@ from colossalai.legacy.nn import (
     VocabParallelEmbedding1D,
 )
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 
 
 def check_linear_col():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -68,7 +68,7 @@ def check_linear_col():
     print_rank_0("linear_col forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     dist.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
     grad = grad.clone()
@@ -91,7 +91,7 @@ def check_linear_col():
 
 
 def check_linear_row():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -137,7 +137,7 @@ def check_linear_row():
     print_rank_0("linear_row forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     dist.broadcast(grad_master, src=0)
     grad = grad_master.clone()
     out.backward(grad)
@@ -159,7 +159,7 @@ def check_linear_row():
 
 
 def check_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -201,7 +201,7 @@ def check_embed():
 
 
 def check_vocab_parallel_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -243,7 +243,7 @@ def check_vocab_parallel_embed():
 
 
 def check_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -309,7 +309,7 @@ def check_classifier_no_given_weight():
 
 
 def check_vocab_parallel_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
 
 def check_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -420,7 +420,7 @@ def check_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -508,7 +508,7 @@ def check_vocab_parallel_loss():
 
 @torch.no_grad()
 def check_linear_row_stream_inference():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
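Note: the legacy layer checks all share the same preamble — fetch the current device once, build a full-size master tensor on it, and broadcast from rank 0 before slicing per rank. A stripped-down sketch, assuming torch.distributed has already been initialized and the shape is illustrative:

```python
import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator

# Shared preamble of the layer checks: one device lookup, then a master tensor
# broadcast from rank 0 so every rank starts from identical data.
device = get_accelerator().get_current_device()
master = torch.randn((4, 8, 8), dtype=torch.float32, device=device)  # illustrative shape
dist.broadcast(master, src=0)
```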
diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
index 0bbc72eca..78bd407b9 100644
--- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py
@@ -1,5 +1,6 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn import (
@@ -16,13 +17,12 @@ from colossalai.legacy.nn import (
     VocabParallelEmbedding2D,
 )
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 
 
 def check_linear():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = HIDDEN_SIZE
@@ -74,7 +74,7 @@ def check_linear():
     print_rank_0("linear forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -103,7 +103,7 @@ def check_linear():
 
 
 def check_layernorm():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     EPS = 1e-12
@@ -139,7 +139,7 @@ def check_layernorm():
     print_rank_0("layer norm forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -154,7 +154,7 @@ def check_layernorm():
 
 
 def check_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -201,7 +201,7 @@ def check_embed():
 
 
 def check_patch_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -274,7 +274,7 @@ def check_patch_embed():
 
 
 def check_vocab_parallel_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -321,7 +321,7 @@ def check_vocab_parallel_embed():
 
 
 def check_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = NUM_CLASSES
@@ -371,7 +371,7 @@ def check_classifier_no_given_weight():
     print_rank_0("classifier (no given weight) forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     # grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -399,7 +399,7 @@ def check_classifier_no_given_weight():
 
 
 def check_vocab_parallel_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
 
 def check_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -519,7 +519,7 @@ def check_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 
 def check_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -608,7 +608,7 @@ def check_loss():
 
 
 def check_vocab_parallel_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -645,7 +645,7 @@ def check_vocab_parallel_loss():
 
 
 # def check_attention():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2
@@ -683,7 +683,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('self attention backward: pass')
 
 # def check_mlp():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 
@@ -716,7 +716,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('mlp backward: pass')
 
 # def check_transformerlayer():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2
diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
index 9c126cefe..4506cfee6 100644
--- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
+++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py
@@ -3,11 +3,11 @@
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
 
@@ -27,7 +27,7 @@ def check_AB():
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
+    A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[j]
@@ -35,7 +35,7 @@ def check_AB():
     A.requires_grad = True
 
     B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
-    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
+    B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(B_master, src=0)
     B = torch.chunk(B_master, DEPTH, dim=0)[i]
     B = torch.chunk(B, DEPTH, dim=-1)[j]
@@ -72,7 +72,7 @@ def check_AB():
     print_rank_0("AB forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -105,7 +105,7 @@ def check_ABT():
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
     dtype = torch.float
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -184,7 +184,7 @@ def check_ATB():
     )
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
index 283e7f683..914607614 100644
--- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py
@@ -1,6 +1,7 @@
 import torch
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn import (
@@ -17,13 +18,12 @@ from colossalai.legacy.nn import (
     VocabParallelEmbedding2p5D,
 )
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import *
 
 
 def check_linear():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -76,7 +76,7 @@ def check_linear():
     print_rank_0("linear forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -104,7 +104,7 @@ def check_linear():
 
 
 def check_layernorm():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     EPS = 1e-12
@@ -141,7 +141,7 @@ def check_layernorm():
     print_rank_0("layer norm forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -156,7 +156,7 @@ def check_layernorm():
 
 
 def check_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -204,7 +204,7 @@ def check_embed():
 
 
 def check_patch_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -278,7 +278,7 @@ def check_patch_embed():
 
 
 def check_vocab_parallel_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -326,7 +326,7 @@ def check_vocab_parallel_embed():
 
 
 def check_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = NUM_CLASSES
@@ -377,7 +377,7 @@ def check_classifier_no_given_weight():
     print_rank_0("classifier (no given weight) forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -405,7 +405,7 @@ def check_classifier_no_given_weight():
 
 
 def check_vocab_parallel_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
 
 def check_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -524,7 +524,7 @@ def check_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 
 def check_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -613,7 +613,7 @@ def check_loss():
 
 
 def check_vocab_parallel_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -650,7 +650,7 @@ def check_vocab_parallel_loss():
 
 
 # def check_attention():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2
@@ -689,7 +689,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('self attention backward: pass')
 
 # def check_mlp():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 
@@ -725,7 +725,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('mlp backward: pass')
 
 # def check_transformerlayer():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2
diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
index 992bd6107..91a15c81d 100644
--- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
+++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py
@@ -1,10 +1,10 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import *
 
@@ -25,7 +25,7 @@ def check_AB():
     k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
+    A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
     A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
@@ -33,7 +33,7 @@ def check_AB():
     A.requires_grad = True
 
     B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
-    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
+    B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(B_master, src=0)
     B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
     B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
@@ -70,7 +70,7 @@ def check_AB():
     print_rank_0("AB forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -103,7 +103,7 @@ def check_ABT():
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
     dtype = torch.float
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -184,7 +184,7 @@ def check_ATB():
     )
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
index a4a4ae9a5..f9f19a17b 100644
--- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
+++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py
@@ -5,6 +5,7 @@ import time
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.legacy.core import global_context
 from colossalai.legacy.nn import (
@@ -23,7 +24,6 @@ from colossalai.legacy.nn import (
 from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.legacy.utils import print_rank_0
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 
@@ -31,7 +31,7 @@ from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_L
 def check_linear():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
 
@@ -84,7 +84,7 @@ def check_linear():
     logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -119,7 +119,7 @@ def check_linear():
 def check_layernorm():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -206,7 +206,7 @@ def check_layernorm():
 def check_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -258,7 +258,7 @@ def check_classifier_no_given_weight():
     logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -306,7 +306,7 @@ def check_classifier_no_given_weight():
 def check_vocab_parallel_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 def check_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -463,7 +463,7 @@ def check_classifier_given_embed_weight():
     logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -497,7 +497,7 @@ def check_classifier_given_embed_weight():
 def check_vocab_parallel_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 def check_patch_embed():
     rank = torch.distributed.get_rank()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     logger = get_dist_logger()
     torch.float32
 
@@ -678,7 +678,7 @@ def check_patch_embed():
 
 def check_embed():
     rank = torch.distributed.get_rank()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     logger = get_dist_logger()
     torch.float32
 
@@ -746,7 +746,7 @@ def check_embed():
 
 def check_vocab_parallel_embed():
     rank = torch.distributed.get_rank()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     logger = get_dist_logger()
     torch.float32
 
@@ -823,7 +823,7 @@ def check_vocab_parallel_embed():
 def check_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -876,7 +876,7 @@ def check_loss():
 def check_vocab_parallel_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
index aa4d5d6ce..f4ad0d6d1 100644
--- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
+++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py
@@ -1,9 +1,9 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn import TransformerSelfAttentionRing
-from colossalai.utils import get_current_device
 
 
 def check_selfattention():
@@ -13,10 +13,10 @@ def check_selfattention():
     HIDDEN_SIZE = 16
 
     layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
-    layer = layer.to(get_current_device())
+    layer = layer.to(get_accelerator().get_current_device())
 
-    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
+    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device())
     attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
-        get_current_device()
+        get_accelerator().get_current_device()
     )
     layer(hidden_states, attention_mask)
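Note: the sequence-parallel check follows the same recipe as the other legacy tests — move the module and its inputs to the accelerator's current device before the forward pass. In isolation (a plain Linear stands in for the TransformerSelfAttentionRing used in the real check):

```python
import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator

# Move a module and its inputs onto the active accelerator, then run a forward.
device = get_accelerator().get_current_device()
layer = nn.Linear(16, 16).to(device)
hidden_states = torch.rand(8, 4, 16).to(device)
out = layer(hidden_states)
```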
diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
index a5a2d3857..cab111358 100644
--- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
+++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.distributed as dist
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import (
     recv_backward,
     recv_forward,
@@ -18,7 +19,6 @@ from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.initialize import launch
 from colossalai.logging import get_dist_logger
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 BATCH_SIZE = 4
 SEQ_LENGTH = 2
@@ -73,7 +73,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
 
 def check_comm(size, rank, prev_rank, next_rank, logger):
     dtype = torch.float32
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
     grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
     tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py
index 9df7cf75a..4993df4f3 100644
--- a/tests/test_legacy/test_utils/test_memory.py
+++ b/tests/test_legacy/test_utils/test_memory.py
@@ -1,15 +1,15 @@
 import pytest
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.testing import spawn
-from colossalai.utils.device import get_current_device
 
 
 def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
-    frac1 = colo_device_memory_capacity(get_current_device())
+    frac1 = colo_device_memory_capacity(get_accelerator().get_current_device())
     colo_set_process_memory_fraction(0.5)
-    frac2 = colo_device_memory_capacity(get_current_device())
+    frac2 = colo_device_memory_capacity(get_accelerator().get_current_device())
     assert frac2 * 2 == frac1
 
 
diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
index b5f2be705..9975cc04f 100644
--- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
+++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py
@@ -4,12 +4,12 @@ from torch.nn.parameter import Parameter
 from torch.nn.utils import clip_grad_norm_
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec
 from colossalai.legacy.utils.common import clip_grad_norm
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 
 def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
@@ -36,7 +36,7 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
 @parameterize("norm_type", [2.0, 3.0, float("inf")])
 def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
     print(f"{world_size}, {dtype}, {device}, {norm_type}")
-    cuda_device = get_current_device()
+    cuda_device = get_accelerator().get_current_device()
     devices = [cuda_device] * 4
     if device == "cpu":
         devices = [torch.device("cpu")] * 4
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index 3fac62472..a349bc5a9 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -4,10 +4,10 @@ import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from tests.test_moe.moe_utils import MoeGradientHandler
 
 BATCH_SIZE = 4
@@ -38,7 +38,7 @@ def run_test(rank, world_size, port):
         layer_list.append(moe_layer)
 
     model = nn.ModuleList(layer_list)
-    model = model.to(get_current_device())
+    model = model.to(get_accelerator().get_current_device())
     dist_dict = MOE_MANAGER.parallel_info_dict
     assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
@@ -52,7 +52,7 @@ def run_test(rank, world_size, port):
 
     rank = dist.get_rank()
     torch.cuda.manual_seed(78 + rank)
-    data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
+    data = torch.randn(BATCH_SIZE, DIM, device=get_accelerator().get_current_device())
     grad = torch.randn_like(data)
 
     MOE_MANAGER.reset_loss()
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 255ec7444..62d61a3d4 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -3,10 +3,10 @@ import torch
 import torch.distributed as dist
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 BATCH_SIZE = 4
 NUM_EXPERTS = 4
@@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     torch.manual_seed(rs + local_rank)  # set each process has different random seed
 
     # get randomized data
-    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(
+        BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
+    )
 
     layer = SparseMLP(
         hidden_size=hidden_size,
@@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
         router_top_k=topk,
         router_capacity_factor_train=1.0,
     )
-    layer = layer.to(get_current_device())
+    layer = layer.to(get_accelerator().get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
 
@@ -45,7 +47,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.enable_kernel = False
     old_out = layer(tokens)
     ech = old_out.shape
-    grad = torch.randn(ech, device=get_current_device())
+    grad = torch.randn(ech, device=get_accelerator().get_current_device())
     old_out.backward(grad)  # get gradient
 
     # save all results
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index bd1103df3..8f51e1663 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -9,11 +9,11 @@ import torch.distributed as dist
 from transformers.models.llama import LlamaConfig
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 sys.path.append(
     os.path.join(
@@ -28,7 +28,7 @@ OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenM
 
 
 def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
-    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
+    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device())
     attention_mask = torch.ones_like(input_ids)
     return {
         "input_ids": input_ids,
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index f87d4c792..74feeeb59 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -7,12 +7,12 @@ import torch
 import torch.distributed as dist
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import sync_moe_model_param
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from tests.test_moe.moe_utils import MoeGradientHandler
 
 
@@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_
         tp_model (MoeModule)
         local_model (MoeModule)
     """
-    for (tp_name, tp_param), (local_name, local_param) in \
-            zip(tp_model.named_parameters(), local_model.named_parameters()):
+    for (tp_name, tp_param), (local_name, local_param) in zip(
+        tp_model.named_parameters(), local_model.named_parameters()
+    ):
         assert tp_name == local_name
         if not is_moe_tensor(tp_param):
             if assert_grad_flag:
@@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
         tp_model (MoeModule)
         ep_model (MoeModule)
     """
-    for (tp_name, tp_param), (ep_name, ep_param) in \
-            zip(tp_model.named_parameters(), ep_model.named_parameters()):
+    for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
         assert tp_name == ep_name
         if not is_moe_tensor(tp_param):
             if assert_grad_flag:
@@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
         local_model (MoeModule)
         ep_model (MoeModule)
     """
-    for (local_name, local_param), (ep_name, ep_param) in \
-            zip(local_model.named_parameters(), ep_model.named_parameters()):
+    for (local_name, local_param), (ep_name, ep_param) in zip(
+        local_model.named_parameters(), ep_model.named_parameters()
+    ):
         assert local_name == ep_name
         if "experts" not in local_name:
             if assert_grad_flag:
@@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
         num_experts=num_experts,
         hidden_size=dim,
         intermediate_size=dim * 2,
-        enable_hierarchical_comm=enable_hierarchical_comm
+        enable_hierarchical_comm=enable_hierarchical_comm,
     )
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="TP")
     tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
-    ep_model = ep_model.to(get_current_device())
-    tp_model = tp_model.to(get_current_device())
-    local_model = local_model.to(get_current_device())
+    ep_model = ep_model.to(get_accelerator().get_current_device())
+    tp_model = tp_model.to(get_accelerator().get_current_device())
+    local_model = local_model.to(get_accelerator().get_current_device())
 
     # sync ep param
     sync_moe_model_param(ep_model)
@@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)
 
     rank = dist.get_rank()
-    input_data = torch.randn(batch_size, dim, device=get_current_device())
+    input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
     # NOTE: ep & tp takes in sharded data for each process
-    shard_data = input_data.detach()[index:index + micro_batch_size]
+    shard_data = input_data.detach()[index : index + micro_batch_size]
 
     out_local = local_model(input_data)
     MOE_MANAGER.reset_loss()
@@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     out_ep = ep_model(shard_data)
     MOE_MANAGER.reset_loss()
 
-    assert torch.allclose(out_tp, out_ep, atol=1e-6), \
-        f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
+    assert torch.allclose(
+        out_tp, out_ep, atol=1e-6
+    ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
     try:
-        out_local_slice = out_local[index:index + micro_batch_size]
-        assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \
-            f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
-    except AssertionError as e:
+        out_local_slice = out_local[index : index + micro_batch_size]
+        assert torch.allclose(
+            out_ep, out_local_slice, atol=1e-6
+        ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
+    except AssertionError:
         """
         e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
             router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
@@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
             The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
         """
         warnings.warn(
-            "EP & TP may result in different behavior from local model. "
-            "Please check the comments for details."
+            "EP & TP may result in different behavior from local model. " "Please check the comments for details."
         )
 
     out_local.mean().backward()
@@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
     try:
         sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
-    except AssertionError as e:
+    except AssertionError:
         warnings.warn(
-            "EP & TP may result in different behavior from local model. "
-            "Please check the comments for details."
+            "EP & TP may result in different behavior from local model. " "Please check the comments for details."
         )
 
 
@@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("config", [
-    {"enable_hierarchical_comm": False},
-    {"enable_hierarchical_comm": True},
-])
+@pytest.mark.parametrize(
+    "config",
+    [
+        {"enable_hierarchical_comm": False},
+        {"enable_hierarchical_comm": True},
+    ],
+)
 @rerun_if_address_is_in_use()
 def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
     spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
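The try/except blocks in test_moe_ep_tp.py above tolerate mismatches between the local and the EP/TP models because each sharded router only sees its own micro-batch, so its per-expert capacity shrinks and some tokens can be dropped, exactly as the inline comment's 4-token example describes. A rough numeric sketch of that effect, assuming the common top-k capacity rule `ceil(tokens_seen * capacity_factor * top_k / num_experts)` (the exact formula inside SparseMLP may differ):

    import math

    def per_expert_capacity(tokens_seen: int, num_experts: int,
                            capacity_factor: float = 1.0, top_k: int = 1) -> int:
        # generic top-k MoE capacity rule; assumed here for illustration only
        return math.ceil(tokens_seen * capacity_factor * top_k / num_experts)

    # local model: one router sees all 4 tokens -> capacity 2 per expert,
    # so routing tokens [0, 1] -> expert 0 and [2, 3] -> expert 1 keeps every token
    print(per_expert_capacity(tokens_seen=4, num_experts=2))  # 2

    # EP/TP: two sharded routers see 2 tokens each -> capacity 1 per expert,
    # so if both of a router's tokens pick the same expert, one token is dropped
    print(per_expert_capacity(tokens_seen=2, num_experts=2))  # 1
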
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index 95c0e715d..2f08a335d 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -3,11 +3,11 @@ import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe.experts import MLPExperts
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 HIDDEN_SIZE = 4
 INTERMEDIATE_SIZE = 8
@@ -46,7 +46,7 @@ def run_moe_init(expert_parallel):
     assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
 
     model = nn.ModuleList([exp0, exp1, exp2])
-    model = model.to(get_current_device())
+    model = model.to(get_accelerator().get_current_device())
     sync_moe_model_param(model)
 
     # MOE experts layout success when ep_size = 1
diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py
index c136f78a1..2ff4b3016 100644
--- a/tests/test_optimizer/test_adam_kernel.py
+++ b/tests/test_optimizer/test_adam_kernel.py
@@ -8,7 +8,8 @@ import pytest
 import torch
 from torch import Tensor
 
-from colossalai.utils import get_current_device, multi_tensor_applier
+from colossalai.accelerator import get_accelerator
+from colossalai.utils import multi_tensor_applier
 
 _FUSED_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
@@ -155,7 +156,9 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
         rtol, atol = 1e-3, 1e-3
     if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
         rtol, atol = 4e-3, 4e-3
-    check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)
+    check_adam_kernel(
+        FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol
+    )
 
 
 @pytest.mark.parametrize("adamw", [False, True])
diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py
index 1665711ce..5ebe2a128 100644
--- a/tests/test_pipeline/test_p2p_communication.py
+++ b/tests/test_pipeline/test_p2p_communication.py
@@ -3,11 +3,11 @@ import torch
 import torch.distributed as dist
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 
 def check_p2p_communication():
@@ -17,7 +17,7 @@ def check_p2p_communication():
 
     rank = dist.get_rank()
 
-    tensor = torch.ones(1, device=get_current_device())
+    tensor = torch.ones(1, device=get_accelerator().get_current_device())
 
     if rank == 0:
         p2p.send_forward(tensor)
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index 5977c706f..e4dc569b8 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -4,15 +4,15 @@ import torch.distributed as dist
 from torch.distributed.distributed_c10d import _get_default_group
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.tensor import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.chunk import Chunk
 
 
 def dist_sum(x):
-    temp = torch.tensor([x], device=get_current_device())
+    temp = torch.tensor([x], device=get_accelerator().get_current_device())
     dist.all_reduce(temp)
     return temp.item()
 
@@ -66,7 +66,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
         assert my_chunk.cpu_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cpu"
         assert my_chunk.can_move
-        my_chunk.shard_move(get_current_device())
+        my_chunk.shard_move(get_accelerator().get_current_device())
     else:
         assert my_chunk.cuda_global_chunk.size(0) == 1024
         assert my_chunk.device_type == "cuda"
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 21afff753..3a9742e01 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd
@@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd(
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
 ):
-    init_device = get_current_device()
+    init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index 35323e516..36a803492 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -6,10 +6,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd
@@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 def exam_gemini_grad_acc(
     placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
 ):
-    init_device = get_current_device()
+    init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index 152bf2895..7f3c7176e 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -7,11 +7,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
@@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
 
 
 def single_chunk_init(model: torch.nn.Module, placement_config: dict):
-    model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
+    model = GeminiDDP(
+        model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config
+    )
     return model
 
 
@@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
-    init_dev = get_current_device()
+    init_dev = get_accelerator().get_current_device()
     model = model_builder().to(init_dev)
 
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 405d7d789..71bb27b4a 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd
@@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
 
     model = GeminiDDP(
         model,
-        chunk_init_device=get_current_device(),
+        chunk_init_device=get_accelerator().get_current_device(),
         search_range_m=1,
         pin_memory=True,
         mixed_precision=mixed_precision,
diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py
index e99f6d59b..cf3658bf9 100644
--- a/tests/test_zero/test_gemini/test_search.py
+++ b/tests/test_zero/test_gemini/test_search.py
@@ -2,8 +2,8 @@ import pytest
 import torch
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
 from tests.kit.model_zoo import model_zoo
 
@@ -34,7 +34,7 @@ def exam_chunk_manager():
     sharded_ddp_model = model_builder()
     chunk_manager = init_chunk_manager(
         sharded_ddp_model,
-        get_current_device(),
+        get_accelerator().get_current_device(),
         hidden_dim=128,
         search_range_m=1,
         min_chunk_size_m=0,
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index 351ae5f67..11f738615 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
-from colossalai.utils import conditional_context, get_current_device
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
 
 
@@ -28,7 +29,7 @@ class MlpModel(nn.Module):
 def exam_zero_1_2_grad_acc():
     local_rank = torch.distributed.get_rank()
     seed_all(2009)
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     # create model
     zero1_model = MlpModel().to(device)
     zero2_model = copy.deepcopy(zero1_model)
@@ -71,7 +72,7 @@ def exam_zero_1_2_grad_acc():
 def exam_zero_1_grad_acc(sync):
     local_rank = torch.distributed.get_rank()
     seed_all(2008)
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     # create models
     zero_model = MlpModel()