From 58d8b8a2dd9a92c1dab3a44d2a35fb30716437c5 Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Fri, 18 Oct 2024 16:48:52 +0800
Subject: [PATCH] [misc] fit torch api upgradation and remove legecy import
 (#6093)

* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
---
 colossalai/accelerator/cuda_accelerator.py            |  2 +-
 colossalai/kernel/jit/option.py                       |  2 +-
 colossalai/pipeline/schedule/_utils.py                | 10 ++++++++--
 .../zero/gemini/memory_tracer/runtime_mem_tracer.py   | 11 ++++++-----
 colossalai/zero/gemini/placement_policy.py            |  3 ++-
 .../features/mixed_precision_training_with_booster.md |  2 +-
 .../features/mixed_precision_training_with_booster.md |  2 +-
 7 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py
index f1ab487d4..32e62b33f 100644
--- a/colossalai/accelerator/cuda_accelerator.py
+++ b/colossalai/accelerator/cuda_accelerator.py
@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
         """
         Return autocast function
         """
-        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+        return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index d392649a6..1ee93e4e0 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,7 +1,6 @@
 import torch
 
 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
 
 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
     dtype: torch.dtype = torch.float32,
 ):
     """Compile JIT functions before the main training steps"""
+    from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
 
     embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
     linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 271b3238f..8f42a9014 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.cuda
+from packaging.version import Version
 from torch.nn import Module
-from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
 
 
 # this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
     return OrderedDict((key, value) for key, value in zip(context, values))
 
 
-_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+if Version(torch.__version__) <= Version("1.13.1"):
+    try:
+        from torch.utils._pytree import register_pytree_node as _register_pytree_node
+    except ImportError:
+        from torch.utils._pytree import _register_pytree_node
+    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
 
 
 def tree_map_hf(fn: Any, pytree: Any):
diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
index b0d258824..81520326f 100644
--- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,10 +1,5 @@
 import torch.nn
 
-from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
-    GradMemStats,
-    GradMemTracerHook,
-    ParamMemTracerHook,
-)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float
 
@@ -27,6 +22,12 @@ class RuntimeMemTracer:
 
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
+        from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+            GradMemStats,
+            GradMemTracerHook,
+            ParamMemTracerHook,
+        )
+
         self.module = module
         self.dtype = dtype
         self._gradstat = GradMemStats()
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 178755d03..2aa8dc3f6 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -8,7 +8,6 @@ import torch
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
         Returns:
             int: the volume of memory that is evicted
         """
+        from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
         start = time()
         cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md
index 65304b1f4..1e17c2bb5 100644
--- a/docs/source/en/features/mixed_precision_training_with_booster.md
+++ b/docs/source/en/features/mixed_precision_training_with_booster.md
@@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
 AMP stands for automatic mixed precision training.
 In Colossal-AI, we have incorporated different implementations of mixed precision training:
 
-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp
 
diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
index da377ceb2..93a69830c 100644
--- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
+++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
@@ -16,7 +16,7 @@
 AMP 代表自动混合精度训练。
 在 Colossal-AI 中, 我们结合了混合精度训练的不同实现:
 
-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp