diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 2114913e1..e9bd7b2ed 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -1,4 +1,6 @@
+import dataclasses
 import math
+import warnings
 from typing import Optional
 
 import loralib as lora
@@ -7,6 +9,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
 
@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
 
@@ -53,31 +61,31 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         def T(w):
             return w.T if self.fan_in_fan_out else w
 
-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
-                    # FIXME(csric): temporary fix
-                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
-                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
-                    self.reset_parameters()
-                else:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                        # FIXME(csric): temporary fix
+                        self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                        self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                        self.reset_parameters()
+                    else:
+                        self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                    delattr(self, "lora_A")
+                    delattr(self, "lora_B")
+                self.merged = True
 
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                delattr(self, "lora_A")
-                delattr(self, "lora_B")
-            self.merged = True
+        return self
 
     def forward(self, x: torch.Tensor):
         def T(w):
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
     return lora_linear
 
 
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index de2a33263..a8ab15eeb 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -192,6 +192,12 @@ def main(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        actor.eval()
     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -227,6 +233,7 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_batch_size", type=int, default=1)
     parser.add_argument("--experience_batch_size", type=int, default=8)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=1e-7)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.9)
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index c9095b365..c1be51f2f 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -157,6 +157,13 @@ def train(args):
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
+
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -186,6 +193,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
     parser.add_argument("--log_dir", default="logs", type=str)
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index a34661762..4f36791be 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -177,6 +177,12 @@ def train(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
@@ -204,6 +210,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default="logs", type=str)