diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index 985d217b..3e452216 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -173,18 +173,20 @@ class TagPromptStyle(AbstractPromptStyle):
 
 class MistralPromptStyle(AbstractPromptStyle):
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        prompt = ""
+        inst_buffer = []
+        text = ""
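+        # Buffer system prompts so they are folded into the next user [INST] block.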
         for message in messages:
-            role = message.role
-            content = message.content or ""
-            if role.lower() == "system":
-                message_from_user = f"[INST] {content.strip()} [/INST]"
-                prompt += message_from_user
-            elif role.lower() == "user":
-                prompt += ""
-                message_from_user = f"[INST] {content.strip()} [/INST]"
-                prompt += message_from_user
-        return prompt
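+            # Mistral instruct format: "<s>[INST] {instruction} [/INST] {response}</s>"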
+            if message.role == MessageRole.SYSTEM:
+                inst_buffer.append(str(message.content).strip())
+            elif message.role == MessageRole.USER:
+                inst_buffer.append(str(message.content).strip())
+                text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
+                inst_buffer.clear()
+            elif message.role == MessageRole.ASSISTANT:
+                text += " " + str(message.content).strip() + "</s>"
+            else:
+                raise ValueError(f"Unknown message role {message.role}")
+        return text
 
     def _completion_to_prompt(self, completion: str) -> str:
         return self._messages_to_prompt(