From b10339df7cfac1eee6945df37a80fd1c38f42289 Mon Sep 17 00:00:00 2001
From: BurkeHulk
Date: Mon, 21 Oct 2024 13:55:43 +0800
Subject: [PATCH 1/2] fix lora ckpt save format (ColoTensor to Tensor)

---
 colossalai/booster/plugin/low_level_zero_plugin.py        | 3 ++-
 colossalai/booster/plugin/torch_ddp_plugin.py             | 6 +++++-
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 5 ++++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index b167b5c7a..97fabe63a 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -290,7 +290,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
 
 
 class LowLevelZeroPlugin(DPPluginBase):
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index ec7ce7f9a..aa4d35cd4 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,10 +1,12 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.utils._pytree import tree_map
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
@@ -134,7 +136,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))
 
 
 class TorchDDPModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 3b6917d32..4ca1353d8 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))

From 6d6cafabe28e6665c35fb585ac162237e09a789b Mon Sep 17 00:00:00 2001
From: BurkeHulk
Date: Mon, 21 Oct 2024 14:04:32 +0800
Subject: [PATCH 2/2] pre-commit fix

---
 colossalai/booster/plugin/low_level_zero_plugin.py |  7 +++++--
 colossalai/booster/plugin/torch_ddp_plugin.py      | 10 ++++++----
 .../checkpoint_io/hybrid_parallel_checkpoint_io.py |  8 +++++---
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 97fabe63a..f3a6901ad 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -290,8 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-                                          state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class LowLevelZeroPlugin(DPPluginBase):
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index aa4d35cd4..156a4acf9 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -5,8 +5,8 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
 from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
@@ -136,9 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
-                                                              peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class TorchDDPModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 4ca1353d8..e6abf59e3 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -957,6 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
-                                                              peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
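
Note on the change: every patched call site now passes an explicit state_dict in which each tensor
entry is replaced by its underlying .data, so that (per the patch title) ColoTensor entries are
serialized as plain torch.Tensor objects rather than ColossalAI wrapper types. A minimal standalone
sketch of that conversion follows; strip_to_plain_tensors is a hypothetical helper name used only
for illustration and is not part of the patched code.

    import torch
    from torch.utils._pytree import tree_map  # same pytree utility the patches import


    def strip_to_plain_tensors(state_dict):
        """Mirror of the lambda used at the patched call sites.

        Walks the state dict and replaces every tensor entry with its .data,
        leaving non-tensor leaves untouched, so the LoRA checkpoint is written
        with ordinary torch.Tensor values.
        """
        return tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)


    # Usage mirroring the patched checkpoint IO (peft_model assumed to be a peft.PeftModel):
    # peft_model.save_pretrained(
    #     checkpoint,
    #     safe_serialization=use_safetensors,
    #     state_dict=strip_to_plain_tensors(peft_model.state_dict()),
    # )

Without the explicit state_dict argument, peft's save_pretrained serializes whatever tensor type
the model's state dict holds, which is what caused the ColoTensor-formatted LoRA checkpoints this
series fixes.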