diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py index 58754119..1e868685 100644 --- a/private_gpt/components/llm/prompt_helper.py +++ b/private_gpt/components/llm/prompt_helper.py @@ -170,36 +170,53 @@ class Llama3PromptStyle(AbstractPromptStyle): """ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: - prompt = "" + prompt = self.BOS # Start with BOS token has_system_message = False for i, message in enumerate(messages): if not message or message.content is None: continue + if message.role == MessageRole.SYSTEM: - prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}" + prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.EOT}" # Use EOT for system message has_system_message = True + elif message.role == MessageRole.USER: + prompt += f"{self.B_INST}user{self.E_INST}\n\n{message.content.strip()}{self.EOT}" + elif message.role == MessageRole.ASSISTANT: + # Check if this is a tool call + if message.additional_kwargs and message.additional_kwargs.get("type") == "tool_call": + tool_call_content = message.content + prompt += f"{self.B_INST}tool_code{self.E_INST}\n\n{tool_call_content}{self.EOT}" + else: + prompt += f"{self.ASSISTANT_INST}\n\n{message.content.strip()}{self.EOT}" + elif message.role == MessageRole.TOOL: + # Assuming additional_kwargs['type'] == 'tool_result' + # and message.content contains the result of the tool call + tool_result_content = message.content + prompt += f"{self.B_INST}tool_output{self.E_INST}\n\n{tool_result_content}{self.EOT}" else: + # Fallback for unknown roles (though ideally all roles should be handled) role_header = f"{self.B_INST}{message.role.value}{self.E_INST}" prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}" - # Add assistant header if the last message is not from the assistant - if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT: - prompt += f"{self.ASSISTANT_INST}\n\n" - - # Add default system prompt if no system message was provided + # Add default system prompt if no system message was provided at the beginning if not has_system_message: - prompt = ( - f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + prompt - ) + default_system_prompt_str = f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT.strip()}{self.EOT}" + prompt = self.BOS + default_system_prompt_str + prompt[len(self.BOS):] # Insert after BOS - # TODO: Implement tool handling logic + # Add assistant header if the model should generate a response + # This is typically when the last message is not from the assistant, + # or when the last message is a tool result. + if messages and (messages[-1].role != MessageRole.ASSISTANT or + (messages[-1].role == MessageRole.TOOL)): # If last message was tool result + prompt += f"{self.ASSISTANT_INST}\n\n" return prompt def _completion_to_prompt(self, completion: str) -> str: + # Ensure BOS is at the start, followed by system prompt, then user message, then assistant prompt return ( - f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT.strip()}{self.EOT}" f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}" f"{self.ASSISTANT_INST}\n\n" ) @@ -213,49 +230,88 @@ class TagPromptStyle(AbstractPromptStyle): <|system|>: your system prompt here. <|user|>: user message here (possibly with context and question) - <|assistant|>: assistant (model) response here. + <|assistant|>: assistant (model) response here. ``` - - FIXME: should we add surrounding `` and `` tags, like in llama2? """ + BOS, EOS = "", "" + def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: - """Format message to prompt with `<|ROLE|>: MSG` style.""" - prompt = "" + """Format message to prompt with `<|ROLE|>: MSG` style, including BOS/EOS.""" + prompt_parts = [] for message in messages: - role = message.role - content = message.content or "" - message_from_user = f"<|{role.lower()}|>: {content.strip()}" - message_from_user += "\n" - prompt += message_from_user - # we are missing the last <|assistant|> tag that will trigger a completion + role_str = str(message.role).lower() + content_str = str(message.content).strip() if message.content else "" + + formatted_message = f"<|{role_str}|>: {content_str}" + if message.role == MessageRole.ASSISTANT: + formatted_message += self.EOS # EOS after assistant's message + prompt_parts.append(formatted_message) + + if not messages: + # If there are no messages, start with BOS and prompt for assistant. + # This assumes the typical case where the user would initiate. + # _completion_to_prompt handles the user-initiated start. + # If system is to start, a system message should be in `messages`. + # So, if messages is empty, it implies we want to prompt for an assistant response + # to an implicit (or empty) user turn. + return f"{self.BOS}<|assistant|>: " + + # Join messages with newline, start with BOS + prompt = self.BOS + "\n".join(prompt_parts) + + # Always end with a prompt for the assistant to speak, ensure it's on a new line + if not prompt.endswith("\n"): + prompt += "\n" prompt += "<|assistant|>: " return prompt def _completion_to_prompt(self, completion: str) -> str: - return self._messages_to_prompt( - [ChatMessage(content=completion, role=MessageRole.USER)] - ) + # A completion is a user message. + # Format: <|user|>: {completion_content}\n<|assistant|>: + content_str = str(completion).strip() + return f"{self.BOS}<|user|>: {content_str}\n<|assistant|>: " class MistralPromptStyle(AbstractPromptStyle): def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: - inst_buffer = [] - text = "" - for message in messages: - if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER: - inst_buffer.append(str(message.content).strip()) + prompt = "" + current_instruction_parts = [] + + for i, message in enumerate(messages): + content = str(message.content).strip() if message.content else "" + # Skip empty non-assistant messages. Assistant messages can be empty (e.g. for function calling). + if not content and message.role != MessageRole.ASSISTANT: + logger.debug("MistralPromptStyle: Skipping empty non-assistant message.") + continue + + if message.role == MessageRole.USER or message.role == MessageRole.SYSTEM: + current_instruction_parts.append(content) elif message.role == MessageRole.ASSISTANT: - text += "[INST] " + "\n".join(inst_buffer) + " [/INST]" - text += " " + str(message.content).strip() + "" - inst_buffer.clear() + if not current_instruction_parts and i == 0: + # First message is assistant, skip. + logger.warning( + "MistralPromptStyle: First message is from assistant, skipping." + ) + continue + if current_instruction_parts: + # Only add if prompt is empty, otherwise, assistant responses follow user turns. + bos_token = "" if not prompt else "" + prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]" + current_instruction_parts = [] + # Assistant content can be empty, e.g. for tool calls that will be handled later + prompt += " " + content + "" else: - raise ValueError(f"Unknown message role {message.role}") + logger.warning( + f"MistralPromptStyle: Unknown message role {message.role} encountered. Skipping." + ) - if len(inst_buffer) > 0: - text += "[INST] " + "\n".join(inst_buffer) + " [/INST]" + # If there are pending instructions (i.e., last message was user/system) + if current_instruction_parts: + bos_token = "" if not prompt else "" + prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]" - return text + return prompt def _completion_to_prompt(self, completion: str) -> str: return self._messages_to_prompt( @@ -265,17 +321,27 @@ class MistralPromptStyle(AbstractPromptStyle): class ChatMLPromptStyle(AbstractPromptStyle): def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: - prompt = "<|im_start|>system\n" + prompt = "" for message in messages: - role = message.role - content = message.content or "" - if role.lower() == "system": - message_from_user = f"{content.strip()}" - prompt += message_from_user - elif role.lower() == "user": - prompt += "<|im_end|>\n<|im_start|>user\n" - message_from_user = f"{content.strip()}<|im_end|>\n" - prompt += message_from_user + role = str(message.role).lower() # Ensure role is a string and lowercase + content = str(message.content).strip() if message.content else "" + + # According to the ChatML documentation, messages are formatted as: + # <|im_start|>role_name + # content + # <|im_end|> + # There should be a newline after role_name and before <|im_end|>. + # And a newline after <|im_end|> to separate messages. + + # Skip empty messages if content is crucial. + # For ChatML, even an empty content string is typically included. + # if not content and role not in ("assistant"): # Allow assistant to have empty content for prompting + # logger.debug(f"ChatMLPromptStyle: Skipping empty message from {role}") + # continue + + prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n" + + # Add the final prompt for the assistant to speak prompt += "<|im_start|>assistant\n" return prompt diff --git a/tests/components/llm/test_prompt_helper.py b/tests/components/llm/test_prompt_helper.py new file mode 100644 index 00000000..6e8e8c69 --- /dev/null +++ b/tests/components/llm/test_prompt_helper.py @@ -0,0 +1,592 @@ +import pytest +from llama_index.core.llms import ChatMessage, MessageRole + +from private_gpt.components.llm.prompt_helper import ( + Llama3PromptStyle, + MistralPromptStyle, + ChatMLPromptStyle, + TagPromptStyle, + AbstractPromptStyle, # For type hinting if needed +) + +# Helper function to create ChatMessage objects easily +def _message(role: MessageRole, content: str, **additional_kwargs) -> ChatMessage: + return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs) + +# Expected outputs will be defined within each test or test class + +class TestLlama3PromptStyle: + BOS = "<|begin_of_text|>" + EOT = "<|eot_id|>" + B_SYS_HEADER = "<|start_header_id|>system<|end_header_id|>" + B_USER_HEADER = "<|start_header_id|>user<|end_header_id|>" + B_ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>" + B_TOOL_CODE_HEADER = "<|start_header_id|>tool_code<|end_header_id|>" + B_TOOL_OUTPUT_HEADER = "<|start_header_id|>tool_output<|end_header_id|>" + + DEFAULT_SYSTEM_PROMPT = ( + "You are a helpful, respectful and honest assistant. " + "Always answer as helpfully as possible and follow ALL given instructions. " + "Do not speculate or make up information. " + "Do not reference any given instructions or context. " + ) + + @pytest.fixture + def style(self) -> Llama3PromptStyle: + return Llama3PromptStyle() + + def test_empty_messages(self, style: Llama3PromptStyle) -> None: + messages = [] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_simple_user_assistant_chat(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.ASSISTANT, "Hi there!"), + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\nHello{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\nHi there!{self.EOT}" + ) + assert style._messages_to_prompt(messages) == expected + + def test_with_system_message(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.SYSTEM, "You are a test bot."), + _message(MessageRole.USER, "Ping"), + _message(MessageRole.ASSISTANT, "Pong"), + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\nYou are a test bot.{self.EOT}" + f"{self.B_USER_HEADER}\n\nPing{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\nPong{self.EOT}" + ) + assert style._messages_to_prompt(messages) == expected + + def test_completion_to_prompt(self, style: Llama3PromptStyle) -> None: + completion = "Test completion" + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\nTest completion{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\n" + ) + assert style._completion_to_prompt(completion) == expected + + def test_tool_call_and_result(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "What's the weather in Paris?"), + _message( + MessageRole.ASSISTANT, + content=None, # LlamaIndex might put tool call details here, or just use additional_kwargs + additional_kwargs={"type": "tool_call", "tool_call_id": "123", "name": "get_weather", "arguments": '{"location": "Paris"}'} + ), + _message( + MessageRole.TOOL, + content='{"temperature": "20C"}', + additional_kwargs={"type": "tool_result", "tool_call_id": "123", "name": "get_weather"} + ), + ] + # Note: The current Llama3PromptStyle implementation uses message.content for tool call/result content. + # If additional_kwargs are used to structure tool calls (like OpenAI), the style needs to be adapted. + # For this test, we assume content holds the direct string for tool_code and tool_output. + # Let's adjust the messages based on current implementation that uses .content for tool_code/output + messages_for_current_impl = [ + _message(MessageRole.USER, "What's the weather in Paris?"), + _message( + MessageRole.ASSISTANT, + content='get_weather({"location": "Paris"})', # Simplified tool call content + additional_kwargs={"type": "tool_call"} + ), + _message( + MessageRole.TOOL, + content='{"temperature": "20C"}', + additional_kwargs={"type": "tool_result"} # No specific tool_call_id or name used by current style from additional_kwargs + ), + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\nWhat's the weather in Paris?{self.EOT}" + f"{self.B_TOOL_CODE_HEADER}\n\nget_weather({{\"location\": \"Paris\"}}){self.EOT}" # Content is 'get_weather(...)' + f"{self.B_TOOL_OUTPUT_HEADER}\n\n{{\"temperature\": \"20C\"}}{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\n" # Assistant should respond after tool result + ) + assert style._messages_to_prompt(messages_for_current_impl) == expected + + def test_multiple_interactions_with_tools(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Can you search for prompt engineering techniques?"), + _message(MessageRole.ASSISTANT, content="Okay, I will search for that.", additional_kwargs={}), # Normal assistant message + _message(MessageRole.ASSISTANT, content='search_web({"query": "prompt engineering techniques"})', additional_kwargs={"type": "tool_call"}), + _message(MessageRole.TOOL, content='[Result 1: ...]', additional_kwargs={"type": "tool_result"}), + _message(MessageRole.ASSISTANT, content="I found one result. Should I look for more?", additional_kwargs={}), + _message(MessageRole.USER, "Yes, please find another one."), + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\nCan you search for prompt engineering techniques?{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\nOkay, I will search for that.{self.EOT}" + f"{self.B_TOOL_CODE_HEADER}\n\nsearch_web({{\"query\": \"prompt engineering techniques\"}}){self.EOT}" + f"{self.B_TOOL_OUTPUT_HEADER}\n\n[Result 1: ...]{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\nI found one result. Should I look for more?{self.EOT}" + f"{self.B_USER_HEADER}\n\nYes, please find another one.{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_ending_with_user_message(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "First message"), + _message(MessageRole.ASSISTANT, "First response"), + _message(MessageRole.USER, "Second message, expecting response"), + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\nFirst message{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\nFirst response{self.EOT}" + f"{self.B_USER_HEADER}\n\nSecond message, expecting response{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_ending_with_tool_result(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Find info on X."), + _message(MessageRole.ASSISTANT, content='search({"topic": "X"})', additional_kwargs={"type": "tool_call"}), + _message(MessageRole.TOOL, content="Info about X found.", additional_kwargs={"type": "tool_result"}), + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\nFind info on X.{self.EOT}" + f"{self.B_TOOL_CODE_HEADER}\n\nsearch({{\"topic\": \"X\"}}){self.EOT}" + f"{self.B_TOOL_OUTPUT_HEADER}\n\nInfo about X found.{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_message_with_empty_content(self, style: Llama3PromptStyle) -> None: + # Llama3PromptStyle skips messages with None content, but not necessarily empty string. + # Let's test with an empty string for user, and None for assistant (which should be skipped) + messages = [ + _message(MessageRole.USER, ""), # Empty string content + _message(MessageRole.ASSISTANT, None), # None content, should be skipped + _message(MessageRole.USER, "Follow up") + ] + expected = ( + f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}" + f"{self.B_USER_HEADER}\n\n{self.EOT}" # Empty content for user + f"{self.B_USER_HEADER}\n\nFollow up{self.EOT}" # Assistant message was skipped + f"{self.B_ASSISTANT_HEADER}\n\n" + ) + # The style's loop: `if not message or message.content is None: continue` + # An empty string `""` is not `None`, so it should be included. + assert style._messages_to_prompt(messages) == expected + + def test_system_message_not_first(self, style: Llama3PromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.SYSTEM, "System message in the middle (unusual)."), + _message(MessageRole.ASSISTANT, "Hi there!"), + ] + # The current implementation processes system messages whenever they appear. + # If a system message appears, it sets `has_system_message = True`. + # If NO system message appears, a default one is prepended. + # If one DOES appear, it's used, and default is not prepended. + expected = ( + f"{self.BOS}" + # Default system prompt is NOT added because a system message IS present. + f"{self.B_USER_HEADER}\n\nHello{self.EOT}" + f"{self.B_SYS_HEADER}\n\nSystem message in the middle (unusual).{self.EOT}" + f"{self.B_ASSISTANT_HEADER}\n\nHi there!{self.EOT}" + ) + assert style._messages_to_prompt(messages) == expected + + +class TestMistralPromptStyle: + @pytest.fixture + def style(self) -> MistralPromptStyle: + return MistralPromptStyle() + + def test_empty_messages(self, style: MistralPromptStyle) -> None: + messages = [] + # The refactored version should produce an empty string if no instructions are pending. + # Or, if it were to prompt for something, it might be "[INST] [/INST]" or just "" + # Based on current refactored code: if current_instruction_parts is empty, it returns prompt ("") + assert style._messages_to_prompt(messages) == "" + + def test_simple_user_assistant_chat(self, style: MistralPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.ASSISTANT, "Hi there!"), + ] + expected = "[INST] Hello [/INST] Hi there!" + assert style._messages_to_prompt(messages) == expected + + def test_with_system_message(self, style: MistralPromptStyle) -> None: + # System messages are treated like user messages in the current Mistral impl + messages = [ + _message(MessageRole.SYSTEM, "You are helpful."), + _message(MessageRole.USER, "Ping"), + _message(MessageRole.ASSISTANT, "Pong"), + ] + expected = "[INST] You are helpful.\nPing [/INST] Pong" + assert style._messages_to_prompt(messages) == expected + + def test_completion_to_prompt(self, style: MistralPromptStyle) -> None: + completion = "Test completion" + # This will call _messages_to_prompt with [ChatMessage(role=USER, content="Test completion")] + expected = "[INST] Test completion [/INST]" + assert style._completion_to_prompt(completion) == expected + + def test_consecutive_user_messages(self, style: MistralPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "First part."), + _message(MessageRole.USER, "Second part."), + _message(MessageRole.ASSISTANT, "Understood."), + ] + expected = "[INST] First part.\nSecond part. [/INST] Understood." + assert style._messages_to_prompt(messages) == expected + + def test_ending_with_user_message(self, style: MistralPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Hi"), + _message(MessageRole.ASSISTANT, "Hello"), + _message(MessageRole.USER, "How are you?"), + ] + # Note: The previous prompt had "[INST] Hi [/INST] Hello" + # The new user message should start a new [INST] block if prompt was not empty. + # Current logic: bos_token = "" if not prompt else "" + # Since prompt is not empty after "Hello", bos_token will be "". + # This might be an issue with the current Mistral refactor if strict BOS per turn is needed. + # The current refactored code for Mistral: + # prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]" + # If prompt = "[INST] Hi [/INST] Hello", then bos_token is "", so it becomes: + # "[INST] Hi [/INST] Hello[INST] How are you? [/INST]" -> This seems correct for continued conversation. + expected = "[INST] Hi [/INST] Hello[INST] How are you? [/INST]" + assert style._messages_to_prompt(messages) == expected + + def test_initial_assistant_message_skipped(self, style: MistralPromptStyle, caplog) -> None: + messages = [ + _message(MessageRole.ASSISTANT, "I speak first!"), + _message(MessageRole.USER, "Oh, hello there."), + ] + # The first assistant message should be skipped with a warning. + # The prompt should then start with the user message. + expected = "[INST] Oh, hello there. [/INST]" + with caplog.at_level("WARNING"): + assert style._messages_to_prompt(messages) == expected + assert "MistralPromptStyle: First message is from assistant, skipping." in caplog.text + + def test_multiple_assistant_messages_in_a_row(self, style: MistralPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "User message"), + _message(MessageRole.ASSISTANT, "Assistant first response."), + _message(MessageRole.ASSISTANT, "Assistant second response (after no user message)."), + ] + # current_instruction_parts will be empty when processing the second assistant message. + # The logic is: + # if current_instruction_parts: prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]" + # prompt += " " + content + "" + # So, it will correctly append the second assistant message without a new [INST] + expected = ("[INST] User message [/INST] Assistant first response." + " Assistant second response (after no user message).") + assert style._messages_to_prompt(messages) == expected + + def test_system_user_assistant_alternating(self, style: MistralPromptStyle) -> None: + messages = [ + _message(MessageRole.SYSTEM, "System setup."), + _message(MessageRole.USER, "User query 1."), + _message(MessageRole.ASSISTANT, "Assistant answer 1."), + _message(MessageRole.USER, "User query 2."), # System messages are part of INST with user + _message(MessageRole.ASSISTANT, "Assistant answer 2."), + ] + expected = ("[INST] System setup.\nUser query 1. [/INST] Assistant answer 1." + "[INST] User query 2. [/INST] Assistant answer 2.") + assert style._messages_to_prompt(messages) == expected + + def test_empty_content_messages(self, style: MistralPromptStyle, caplog) -> None: + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.USER, None), # Skipped by `if not content and message.role != MessageRole.ASSISTANT:` + _message(MessageRole.USER, ""), # Skipped by the same logic + _message(MessageRole.ASSISTANT, "Hi"), + _message(MessageRole.ASSISTANT, ""), # Empty assistant message, kept + _message(MessageRole.ASSISTANT, None),# Empty assistant message, kept (content becomes "") + ] + # The refactored code skips empty non-assistant messages. + # Empty assistant messages (content="" or content=None) are kept. + expected = ("[INST] Hello [/INST] Hi" + " " # From assistant with content="" + " ") # From assistant with content=None (becomes "") + + with caplog.at_level("DEBUG"): # The skipping messages are logged at DEBUG level + actual = style._messages_to_prompt(messages) + assert actual == expected + # Check that specific debug messages for skipping are present + assert "Skipping empty non-assistant message." in caplog.text # For the None and "" user messages + + +class TestChatMLPromptStyle: + IM_START = "<|im_start|>" + IM_END = "<|im_end|>" + + @pytest.fixture + def style(self) -> ChatMLPromptStyle: + return ChatMLPromptStyle() + + def test_empty_messages(self, style: ChatMLPromptStyle) -> None: + messages = [] + # Expected: just the final assistant prompt + expected = f"{self.IM_START}assistant\n" + assert style._messages_to_prompt(messages) == expected + + def test_simple_user_assistant_chat(self, style: ChatMLPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.ASSISTANT, "Hi there!"), + ] + expected = ( + f"{self.IM_START}user\nHello{self.IM_END}\n" + f"{self.IM_START}assistant\nHi there!{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_with_system_message(self, style: ChatMLPromptStyle) -> None: + messages = [ + _message(MessageRole.SYSTEM, "You are ChatML bot."), + _message(MessageRole.USER, "Ping"), + _message(MessageRole.ASSISTANT, "Pong"), + ] + expected = ( + f"{self.IM_START}system\nYou are ChatML bot.{self.IM_END}\n" + f"{self.IM_START}user\nPing{self.IM_END}\n" + f"{self.IM_START}assistant\nPong{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_completion_to_prompt(self, style: ChatMLPromptStyle) -> None: + completion = "Test user input" + # This will call _messages_to_prompt with [ChatMessage(role=USER, content="Test user input")] + expected = ( + f"{self.IM_START}user\nTest user input{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + assert style._completion_to_prompt(completion) == expected + + def test_multiple_turns(self, style: ChatMLPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "First user message."), + _message(MessageRole.ASSISTANT, "First assistant response."), + _message(MessageRole.USER, "Second user message."), + _message(MessageRole.ASSISTANT, "Second assistant response."), + ] + expected = ( + f"{self.IM_START}user\nFirst user message.{self.IM_END}\n" + f"{self.IM_START}assistant\nFirst assistant response.{self.IM_END}\n" + f"{self.IM_START}user\nSecond user message.{self.IM_END}\n" + f"{self.IM_START}assistant\nSecond assistant response.{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_message_with_empty_content(self, style: ChatMLPromptStyle) -> None: + # ChatML typically includes messages even with empty content. + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.ASSISTANT, ""), # Empty string content + _message(MessageRole.USER, "Follow up") + ] + expected = ( + f"{self.IM_START}user\nHello{self.IM_END}\n" + f"{self.IM_START}assistant\n{self.IM_END}\n" # Empty content for assistant + f"{self.IM_START}user\nFollow up{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_message_with_none_content(self, style: ChatMLPromptStyle) -> None: + # ChatML typically includes messages even with empty content (None becomes empty string). + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.ASSISTANT, None), # None content + _message(MessageRole.USER, "Follow up") + ] + expected = ( + f"{self.IM_START}user\nHello{self.IM_END}\n" + f"{self.IM_START}assistant\n{self.IM_END}\n" # Empty content for assistant + f"{self.IM_START}user\nFollow up{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + assert style._messages_to_prompt(messages) == expected + + def test_correct_token_usage_and_newlines(self, style: ChatMLPromptStyle) -> None: + # Validates: <|im_start|>role\ncontent<|im_end|>\n ... <|im_start|>assistant\n + messages = [_message(MessageRole.USER, "Test")] + expected = ( + f"{self.IM_START}user\nTest{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + actual = style._messages_to_prompt(messages) + assert actual == expected + assert actual.count(self.IM_START) == 2 + assert actual.count(self.IM_END) == 1 + assert actual.endswith(f"{self.IM_START}assistant\n") + # Check newlines: after role, after content (before im_end), after im_end + # <|im_start|>user\nTest<|im_end|>\n<|im_start|>assistant\n + # Role is followed by \n. Content is on its own line implicitly. im_end is followed by \n. + # The structure f"{IM_START}{role}\n{content}{IM_END}\n" ensures this. + user_part = f"{self.IM_START}user\nTest{self.IM_END}\n" + assert user_part in actual + + messages_with_system = [ + _message(MessageRole.SYSTEM, "Sys"), + _message(MessageRole.USER, "Usr") + ] + expected_sys_usr = ( + f"{self.IM_START}system\nSys{self.IM_END}\n" + f"{self.IM_START}user\nUsr{self.IM_END}\n" + f"{self.IM_START}assistant\n" + ) + actual_sys_usr = style._messages_to_prompt(messages_with_system) + assert actual_sys_usr == expected_sys_usr + assert actual_sys_usr.count(self.IM_START) == 3 + assert actual_sys_usr.count(self.IM_END) == 2 + + +class TestTagPromptStyle: + BOS = "" + EOS = "" + + @pytest.fixture + def style(self) -> TagPromptStyle: + return TagPromptStyle() + + def test_empty_messages(self, style: TagPromptStyle) -> None: + messages = [] + # Expected based on current TagPromptStyle: "<|assistant|>: " + expected = f"{self.BOS}<|assistant|>: " + assert style._messages_to_prompt(messages) == expected + + def test_simple_user_assistant_chat(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Hello"), + _message(MessageRole.ASSISTANT, "Hi there!"), + ] + expected = ( + f"{self.BOS}<|user|>: Hello\n" + f"<|assistant|>: Hi there!{self.EOS}\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_with_system_message(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.SYSTEM, "System instructions."), + _message(MessageRole.USER, "Ping"), + _message(MessageRole.ASSISTANT, "Pong"), + ] + expected = ( + f"{self.BOS}<|system|>: System instructions.\n" + f"<|user|>: Ping\n" + f"<|assistant|>: Pong{self.EOS}\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_completion_to_prompt(self, style: TagPromptStyle) -> None: + completion = "Test user input" + # Expected: <|user|>: Test user input\n<|assistant|>: + expected = f"{self.BOS}<|user|>: Test user input\n<|assistant|>: " + assert style._completion_to_prompt(completion) == expected + + def test_bos_eos_placement_multiple_turns(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "User1"), + _message(MessageRole.ASSISTANT, "Assistant1"), + _message(MessageRole.USER, "User2"), + _message(MessageRole.ASSISTANT, "Assistant2"), + ] + expected = ( + f"{self.BOS}<|user|>: User1\n" + f"<|assistant|>: Assistant1{self.EOS}\n" + f"<|user|>: User2\n" + f"<|assistant|>: Assistant2{self.EOS}\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_ending_with_user_message(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "User1"), + _message(MessageRole.ASSISTANT, "Assistant1"), + _message(MessageRole.USER, "User2 (prompting for response)"), + ] + expected = ( + f"{self.BOS}<|user|>: User1\n" + f"<|assistant|>: Assistant1{self.EOS}\n" + f"<|user|>: User2 (prompting for response)\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_message_with_empty_content(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, ""), + _message(MessageRole.ASSISTANT, ""), # Empty assistant response + ] + # Content is stripped, so empty string remains empty. + expected = ( + f"{self.BOS}<|user|>: \n" + f"<|assistant|>: {self.EOS}\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_message_with_none_content(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, None), # Becomes empty string + _message(MessageRole.ASSISTANT, None), # Becomes empty string + ] + expected = ( + f"{self.BOS}<|user|>: \n" + f"<|assistant|>: {self.EOS}\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_only_user_message(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.USER, "Just a user message"), + ] + expected = ( + f"{self.BOS}<|user|>: Just a user message\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected + + def test_only_assistant_message(self, style: TagPromptStyle) -> None: + # This is an unusual case, but the style should handle it. + messages = [ + _message(MessageRole.ASSISTANT, "Only assistant"), + ] + expected = ( + f"{self.BOS}<|assistant|>: Only assistant{self.EOS}\n" + f"<|assistant|>: " # Still prompts for assistant + ) + assert style._messages_to_prompt(messages) == expected + + def test_only_system_message(self, style: TagPromptStyle) -> None: + messages = [ + _message(MessageRole.SYSTEM, "Only system"), + ] + expected = ( + f"{self.BOS}<|system|>: Only system\n" + f"<|assistant|>: " + ) + assert style._messages_to_prompt(messages) == expected