diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py index 1f31fe20..77158200 100644 --- a/private_gpt/components/llm/prompt_helper.py +++ b/private_gpt/components/llm/prompt_helper.py @@ -176,16 +176,18 @@ class MistralPromptStyle(AbstractPromptStyle): inst_buffer = [] text = "" for message in messages: - if message.role == MessageRole.SYSTEM: + if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER: inst_buffer.append(str(message.content).strip()) - elif message.role == MessageRole.USER: - inst_buffer.append(str(message.content).strip()) - text += "[/INST] " + "\n".join(inst_buffer) + " [/INST]" - inst_buffer.clear() elif message.role == MessageRole.ASSISTANT: + text += "[INST] " + "\n".join(inst_buffer) + " [/INST]" text += " " + str(message.content).strip() + "" + inst_buffer.clear() else: raise ValueError(f"Unknown message role {message.role}") + + if len(inst_buffer) > 0: + text += "[INST] " + "\n".join(inst_buffer) + " [/INST]" + return text def _completion_to_prompt(self, completion: str) -> str: diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py index 3b5af914..ef764370 100644 --- a/tests/test_prompt_helper.py +++ b/tests/test_prompt_helper.py @@ -69,17 +69,21 @@ def test_tag_prompt_style_format_with_system_prompt(): def test_mistral_prompt_style_format(): prompt_style = MistralPromptStyle() messages = [ - ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM), - ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER), + ChatMessage(content="A", role=MessageRole.SYSTEM), + ChatMessage(content="B", role=MessageRole.USER), ] - - expected_prompt = ( - "[INST] You are an AI assistant. [/INST]" - "[INST] Hello, how are you doing? [/INST]" - ) - + expected_prompt = "[INST] A\nB [/INST]" assert prompt_style.messages_to_prompt(messages) == expected_prompt + messages2 = [ + ChatMessage(content="A", role=MessageRole.SYSTEM), + ChatMessage(content="B", role=MessageRole.USER), + ChatMessage(content="C", role=MessageRole.ASSISTANT), + ChatMessage(content="D", role=MessageRole.USER), + ] + expected_prompt2 = "[INST] A\nB [/INST] C[INST] D [/INST]" + assert prompt_style.messages_to_prompt(messages2) == expected_prompt2 + def test_chatml_prompt_style_format(): prompt_style = ChatMLPromptStyle()