diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 99e77f7b9..70ad39b67 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -8,13 +8,11 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.utils.safetensors import move_and_save
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    async_save_state_dict,
     async_save_state_dict_shards,
-    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
@@ -59,13 +57,16 @@ class GeneralCheckpointIO(CheckpointIO):
             pass
 
         if use_async:
-            from tensornvme.async_file_io import AsyncFileWriter
-
-            writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            self.async_writers.append(writer)
-            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            new_pinned_state_dict, writers = async_save_state_dict(
+                state_dict,
+                checkpoint,
+                pinned_state_dict,
+                self.N_WRITE_ENTRIES,
+                shard_preprocess=False,
+            )
+            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
 
         else:
             # save the checkpoint
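A minimal sketch of how the refactored branch is exercised from the checkpoint IO layer. Only the `use_async` flag, the per-model pinned-buffer cache, and the `self.async_writers` bookkeeping come from the hunk above; the `save_unsharded_model` signature and the model/path names are assumptions for illustration.

```python
# Hypothetical usage sketch; assumes save_unsharded_model exposes the
# use_async flag referenced in the hunk above. Model and path are placeholders.
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
model = nn.Linear(16, 16)

# With use_async=True the state dict is copied into pinned host buffers cached
# under self.pinned_state_dicts[id(model)], and the background writers returned
# by async_save_state_dict are collected in self.async_writers. Repeated saves
# of the same model reuse the cached pinned buffers instead of re-allocating.
ckpt_io.save_unsharded_model(
    model,
    "model.safetensors",
    gather_dtensor=False,
    use_safetensors=True,
    use_async=True,
)
```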
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index eb8bb2dcf..c8a4c6c1d 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -19,7 +19,7 @@ from colossalai.tensor.d_tensor import (
     to_global,
     to_global_for_customized_distributed_tensor,
 )
-from colossalai.utils.safetensors import move_and_save
+from colossalai.utils.safetensors import move_and_save, save
 
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -266,6 +266,38 @@ def save_state_dict_shards(
     return total_size
 
 
+def async_save_state_dict(
+    state_dict: dict,
+    checkpoint_file_path: str,
+    pinned_state_dict: Optional[Dict[str, torch.Tensor]],
+    n_write_entries: int,
+    shard_preprocess: bool = False,
+    move: bool = True,
+):
+    from tensornvme.async_file_io import AsyncFileWriter
+
+    async_writers = []
+
+    saved_state_dict, metadata = state_dict, None
+    if pinned_state_dict is None:
+        pinned_state_dict = create_pinned_state_dict(saved_state_dict)
+
+    f_writer = AsyncFileWriter(fp=open(checkpoint_file_path, "wb"), n_entries=n_write_entries, backend="pthread")
+    if move:
+        move_and_save(
+            f_writer,
+            state_dict=saved_state_dict,
+            metadata=metadata,
+            state_dict_pinned=pinned_state_dict,
+        )
+    else:
+        for name, tensor in saved_state_dict.items():
+            pinned_state_dict[name].copy_(tensor)
+        save(f_writer=f_writer, state_dict=pinned_state_dict, metadata=metadata)
+    async_writers.append(f_writer)
+    return pinned_state_dict, async_writers
+
+
 def async_save_state_dict_shards(
     sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
     checkpoint: str,
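Taken on its own, the new `async_save_state_dict` helper allocates pinned host buffers on first use, writes through a `tensornvme` `AsyncFileWriter`, and returns both the buffers and the pending writers to the caller. Note that `shard_preprocess` is accepted but not consulted in the body shown in this hunk, and `metadata` stays `None` here. A minimal sketch of calling it directly follows; the tensor names, file path, device, and entry count are illustrative.

```python
# Minimal sketch of the new helper in isolation (illustrative values only).
import torch

from colossalai.checkpoint_io.utils import async_save_state_dict

state_dict = {"weight": torch.randn(4, 4, device="cuda")}

# First call: no pinned buffers yet, so the helper allocates them.
pinned, writers = async_save_state_dict(
    state_dict,
    "layer.safetensors",
    pinned_state_dict=None,
    n_write_entries=32,
)

# Later calls pass the cached buffers back in so pinned memory is reused.
# The returned writers must be drained (as GeneralCheckpointIO does with
# self.async_writers) before the checkpoint is considered complete.
pinned, writers = async_save_state_dict(
    state_dict,
    "layer.safetensors",
    pinned_state_dict=pinned,
    n_write_entries=32,
)
```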
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 8b8cb627f..39ef8bbc9 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -170,9 +170,10 @@ def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor])
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
+    metadata: Optional[Dict[str, str]] = None,
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
-    prepared_data, _, tensor_keys = prepare(state_dict)
+    prepared_data, _, tensor_keys = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
 
     f_writer.write(n.to_bytes(8, byteorder="little"))
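The new `metadata` parameter is inserted before `state_dict_pinned`, so any caller that passed the pinned dict as the third positional argument (as the old `general_checkpoint_io.py` code did) should switch to keyword arguments. A minimal call sketch under that assumption, with illustrative metadata contents:

```python
# Sketch of the updated move_and_save call shape; the metadata dict is illustrative.
import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

state_dict = {"weight": torch.randn(4, 4)}
writer = AsyncFileWriter(fp=open("model.safetensors", "wb"), n_entries=32, backend="pthread")

move_and_save(
    writer,
    state_dict=state_dict,
    metadata={"format": "pt"},  # forwarded to prepare(); presumably lands in the safetensors header
    state_dict_pinned=None,     # or a pre-allocated pinned copy of state_dict
)
```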