diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py index c24289502..b76ae5373 100644 --- a/applications/ChatGPT/chatgpt/trainer/rm.py +++ b/applications/ChatGPT/chatgpt/trainer/rm.py @@ -43,7 +43,7 @@ class RewardModelTrainer(ABC): # train if use_lora > 0: print("Using Lora") - lora.mark_only_lora_as_trainable(self.model) + lora.mark_only_lora_as_trainable(self.model.model) else: self.model.train() for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: