diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 12ffe5fe5..761947344 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                 from colossalai.utils.safetensors import save_nested
 
                 f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-                save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+                save_nested(f_writer, state_dict)
                 self.async_writers.append(f_writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 68fb582e5..f10945763 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -81,6 +81,16 @@ class CPUAdam(NVMeOptimizer):
         # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
 
+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
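+        # Checkpoints restored via the flat safetensors path may hold `step` as a
+        # 0-dim tensor; normalize it back to a plain Python int for the CPU Adam op.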
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 4cbb01163..8f9cce246 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict
 
 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                 v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
-            if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                v2 = tuple(v2)
-            assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+            assert v1 == v2, f"{v1} is not equal to {v2}"
 
 
 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index ad7d3be77..8b8cb627f 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple
 
@@ -12,6 +11,29 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 _TYPES_INV = {v: k for k, v in _TYPES.items()}
+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
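+    # Pickle an arbitrary Python object into a flat byte tensor on `device`;
+    # mirrors the private _object_to_tensor helper in torch.distributed.distributed_c10d.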
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor(..., dtype=...).
+    # Otherwise, it will cause a 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
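+    # Inverse of _object_to_tensor: unpickle the first `tensor_size` bytes back into an object.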
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
 
 
 @dataclass
@@ -28,49 +50,75 @@ class PreparedData:
     offset: int
 
 
-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
+def _cast_to_tensor(obj):
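+    # Tensors pass through unchanged; any other object is pickled into a CPU byte tensor.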
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
 
 
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
+def _cast_to_object(tensor: torch.Tensor):
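+    # Inverse of _cast_to_tensor: unpickle a byte tensor back into its original object.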
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
 
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
 
-    return nested_dict
+def _flatten_optim_state_dict(state_dict: dict, separator: str = ".") -> Tuple[dict, Optional[dict]]:
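+    # Flatten the nested optimizer state dict into a single-level mapping with keys of the
+    # form "state.<param_idx>.<field>" so it can be written with safetensors. Non-tensor
+    # values are pickled into byte tensors and their keys recorded in the returned metadata.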
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{separator}{idx}{separator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": non_tensor_keys}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, separator: str = "."):
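+    # Rebuild the nested optimizer state dict, unpickling every entry whose key was
+    # recorded under "non_tensor_keys" in the metadata written at save time.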
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(separator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+
+    return state_dict
 
 
 def prepare(
@@ -124,10 +172,8 @@ def save(
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
 
 
-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)
 
 
@@ -154,10 +200,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py
index 31c69e961..521ec10bd 100644
--- a/tests/test_checkpoint_io/test_safetensors_async_io.py
+++ b/tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -1,9 +1,9 @@
 import tempfile
-from copy import deepcopy
 
 import torch
+from safetensors.torch import load_file
 
-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
 
 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +11,29 @@ except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 
 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device
 
 
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
+            ],
         }
-        metadata = deepcopy(group_dict)
+
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
-
         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
 
         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,22 @@
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
-
-        load_state_dict = load_flat(model_saved_path)
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)
+
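+        # Exercise move_and_save: write CUDA tensors through pre-allocated pinned-memory
+        # staging buffers, then verify the round trip against the original CPU tensors.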
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)