diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 6ef53c9d1..cd5b70d11 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
 from torch import Tensor
 from torch.fx import Graph, Node
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.profiler import (
     activation_size,
@@ -131,8 +132,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         fwd_mem_peak = 0
         for n in node:
             assert isinstance(n, Node), f'{n} is not a Node'
-            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
-            fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
+                # in this case we need to calculate memory usage directly based on the statistics hooked into node.meta
+                xbar += n.meta['fwd_mem_out']
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+            else:
+                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+
             # minimum flop count is required
             ftime += max(calculate_fwd_time(n), 1.0)
             btime += max(calculate_bwd_time(n), 1.0)
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index bdeaeffed..f7e07ef1e 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -151,6 +151,7 @@ class MetaInfoProp:
         # fetch other memory informations
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
+        graph_info.fwd_mem_out = memory_cost.fwd.activation
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
         graph_info.bwd_mem_out = memory_cost.bwd.activation
 
diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py
index a765e5055..34feefb43 100644
--- a/colossalai/fx/profiler/shard_utils.py
+++ b/colossalai/fx/profiler/shard_utils.py
@@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float:
         fwd_time (float): the result of `fwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["fwd_flop"]
+    return n.meta["fwd_time"]
 
 
 def calculate_bwd_time(n: Node) -> float:
@@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float:
         bwd_time (float): the result of `bwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["bwd_flop"]
+    return n.meta["bwd_time"]