diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py index 985d217b..3e452216 100644 --- a/private_gpt/components/llm/prompt_helper.py +++ b/private_gpt/components/llm/prompt_helper.py @@ -173,18 +173,20 @@ class TagPromptStyle(AbstractPromptStyle): class MistralPromptStyle(AbstractPromptStyle): def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: - prompt = "" + inst_buffer = [] + text = "" for message in messages: - role = message.role - content = message.content or "" - if role.lower() == "system": - message_from_user = f"[INST] {content.strip()} [/INST]" - prompt += message_from_user - elif role.lower() == "user": - prompt += "" - message_from_user = f"[INST] {content.strip()} [/INST]" - prompt += message_from_user - return prompt + if message.role == MessageRole.SYSTEM: + inst_buffer.append(message.content.strip()) + elif message.role == MessageRole.USER: + inst_buffer.append(message.content.strip()) + text += "[/INST] " + "\n".join(inst_buffer) + " [/INST]" + inst_buffer.clear() + elif message.role == MessageRole.ASSISTANT: + text += " " + message.content.strip() + "" + else: + raise ValueError(f"Unknown message role {message.role}") + return text def _completion_to_prompt(self, completion: str) -> str: return self._messages_to_prompt(