mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-06-20 04:32:27 +00:00
Merge 6fd3a23daf
into b7ee43788d
This commit is contained in:
commit
1df0fceb4d
@ -170,36 +170,53 @@ class Llama3PromptStyle(AbstractPromptStyle):
|
||||
"""
|
||||
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
prompt = ""
|
||||
prompt = self.BOS # Start with BOS token
|
||||
has_system_message = False
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if not message or message.content is None:
|
||||
continue
|
||||
|
||||
if message.role == MessageRole.SYSTEM:
|
||||
prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
|
||||
prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.EOT}" # Use EOT for system message
|
||||
has_system_message = True
|
||||
elif message.role == MessageRole.USER:
|
||||
prompt += f"{self.B_INST}user{self.E_INST}\n\n{message.content.strip()}{self.EOT}"
|
||||
elif message.role == MessageRole.ASSISTANT:
|
||||
# Check if this is a tool call
|
||||
if message.additional_kwargs and message.additional_kwargs.get("type") == "tool_call":
|
||||
tool_call_content = message.content
|
||||
prompt += f"{self.B_INST}tool_code{self.E_INST}\n\n{tool_call_content}{self.EOT}"
|
||||
else:
|
||||
prompt += f"{self.ASSISTANT_INST}\n\n{message.content.strip()}{self.EOT}"
|
||||
elif message.role == MessageRole.TOOL:
|
||||
# Assuming additional_kwargs['type'] == 'tool_result'
|
||||
# and message.content contains the result of the tool call
|
||||
tool_result_content = message.content
|
||||
prompt += f"{self.B_INST}tool_output{self.E_INST}\n\n{tool_result_content}{self.EOT}"
|
||||
else:
|
||||
# Fallback for unknown roles (though ideally all roles should be handled)
|
||||
role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
|
||||
prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"
|
||||
|
||||
# Add assistant header if the last message is not from the assistant
|
||||
if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
|
||||
prompt += f"{self.ASSISTANT_INST}\n\n"
|
||||
|
||||
# Add default system prompt if no system message was provided
|
||||
# Add default system prompt if no system message was provided at the beginning
|
||||
if not has_system_message:
|
||||
prompt = (
|
||||
f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + prompt
|
||||
)
|
||||
default_system_prompt_str = f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT.strip()}{self.EOT}"
|
||||
prompt = self.BOS + default_system_prompt_str + prompt[len(self.BOS):] # Insert after BOS
|
||||
|
||||
# TODO: Implement tool handling logic
|
||||
# Add assistant header if the model should generate a response
|
||||
# This is typically when the last message is not from the assistant,
|
||||
# or when the last message is a tool result.
|
||||
if messages and (messages[-1].role != MessageRole.ASSISTANT or
|
||||
(messages[-1].role == MessageRole.TOOL)): # If last message was tool result
|
||||
prompt += f"{self.ASSISTANT_INST}\n\n"
|
||||
|
||||
return prompt
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
# Ensure BOS is at the start, followed by system prompt, then user message, then assistant prompt
|
||||
return (
|
||||
f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
|
||||
f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT.strip()}{self.EOT}"
|
||||
f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
|
||||
f"{self.ASSISTANT_INST}\n\n"
|
||||
)
|
||||
@ -213,49 +230,88 @@ class TagPromptStyle(AbstractPromptStyle):
|
||||
<|system|>: your system prompt here.
|
||||
<|user|>: user message here
|
||||
(possibly with context and question)
|
||||
<|assistant|>: assistant (model) response here.
|
||||
<|assistant|>: assistant (model) response here.</s>
|
||||
```
|
||||
|
||||
FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
|
||||
"""
|
||||
|
||||
BOS, EOS = "<s>", "</s>"
|
||||
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
"""Format message to prompt with `<|ROLE|>: MSG` style."""
|
||||
prompt = ""
|
||||
"""Format message to prompt with `<|ROLE|>: MSG` style, including BOS/EOS."""
|
||||
prompt_parts = []
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content = message.content or ""
|
||||
message_from_user = f"<|{role.lower()}|>: {content.strip()}"
|
||||
message_from_user += "\n"
|
||||
prompt += message_from_user
|
||||
# we are missing the last <|assistant|> tag that will trigger a completion
|
||||
role_str = str(message.role).lower()
|
||||
content_str = str(message.content).strip() if message.content else ""
|
||||
|
||||
formatted_message = f"<|{role_str}|>: {content_str}"
|
||||
if message.role == MessageRole.ASSISTANT:
|
||||
formatted_message += self.EOS # EOS after assistant's message
|
||||
prompt_parts.append(formatted_message)
|
||||
|
||||
if not messages:
|
||||
# If there are no messages, start with BOS and prompt for assistant.
|
||||
# This assumes the typical case where the user would initiate.
|
||||
# _completion_to_prompt handles the user-initiated start.
|
||||
# If system is to start, a system message should be in `messages`.
|
||||
# So, if messages is empty, it implies we want to prompt for an assistant response
|
||||
# to an implicit (or empty) user turn.
|
||||
return f"{self.BOS}<|assistant|>: "
|
||||
|
||||
# Join messages with newline, start with BOS
|
||||
prompt = self.BOS + "\n".join(prompt_parts)
|
||||
|
||||
# Always end with a prompt for the assistant to speak, ensure it's on a new line
|
||||
if not prompt.endswith("\n"):
|
||||
prompt += "\n"
|
||||
prompt += "<|assistant|>: "
|
||||
return prompt
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
return self._messages_to_prompt(
|
||||
[ChatMessage(content=completion, role=MessageRole.USER)]
|
||||
)
|
||||
# A completion is a user message.
|
||||
# Format: <s><|user|>: {completion_content}\n<|assistant|>:
|
||||
content_str = str(completion).strip()
|
||||
return f"{self.BOS}<|user|>: {content_str}\n<|assistant|>: "
|
||||
|
||||
|
||||
class MistralPromptStyle(AbstractPromptStyle):
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
inst_buffer = []
|
||||
text = ""
|
||||
for message in messages:
|
||||
if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER:
|
||||
inst_buffer.append(str(message.content).strip())
|
||||
prompt = ""
|
||||
current_instruction_parts = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
content = str(message.content).strip() if message.content else ""
|
||||
# Skip empty non-assistant messages. Assistant messages can be empty (e.g. for function calling).
|
||||
if not content and message.role != MessageRole.ASSISTANT:
|
||||
logger.debug("MistralPromptStyle: Skipping empty non-assistant message.")
|
||||
continue
|
||||
|
||||
if message.role == MessageRole.USER or message.role == MessageRole.SYSTEM:
|
||||
current_instruction_parts.append(content)
|
||||
elif message.role == MessageRole.ASSISTANT:
|
||||
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
|
||||
text += " " + str(message.content).strip() + "</s>"
|
||||
inst_buffer.clear()
|
||||
if not current_instruction_parts and i == 0:
|
||||
# First message is assistant, skip.
|
||||
logger.warning(
|
||||
"MistralPromptStyle: First message is from assistant, skipping."
|
||||
)
|
||||
continue
|
||||
if current_instruction_parts:
|
||||
# Only add <s> if prompt is empty, otherwise, assistant responses follow user turns.
|
||||
bos_token = "<s>" if not prompt else ""
|
||||
prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]"
|
||||
current_instruction_parts = []
|
||||
# Assistant content can be empty, e.g. for tool calls that will be handled later
|
||||
prompt += " " + content + "</s>"
|
||||
else:
|
||||
raise ValueError(f"Unknown message role {message.role}")
|
||||
logger.warning(
|
||||
f"MistralPromptStyle: Unknown message role {message.role} encountered. Skipping."
|
||||
)
|
||||
|
||||
if len(inst_buffer) > 0:
|
||||
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
|
||||
# If there are pending instructions (i.e., last message was user/system)
|
||||
if current_instruction_parts:
|
||||
bos_token = "<s>" if not prompt else ""
|
||||
prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]"
|
||||
|
||||
return text
|
||||
return prompt
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
return self._messages_to_prompt(
|
||||
@ -265,17 +321,27 @@ class MistralPromptStyle(AbstractPromptStyle):
|
||||
|
||||
class ChatMLPromptStyle(AbstractPromptStyle):
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
prompt = "<|im_start|>system\n"
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content = message.content or ""
|
||||
if role.lower() == "system":
|
||||
message_from_user = f"{content.strip()}"
|
||||
prompt += message_from_user
|
||||
elif role.lower() == "user":
|
||||
prompt += "<|im_end|>\n<|im_start|>user\n"
|
||||
message_from_user = f"{content.strip()}<|im_end|>\n"
|
||||
prompt += message_from_user
|
||||
role = str(message.role).lower() # Ensure role is a string and lowercase
|
||||
content = str(message.content).strip() if message.content else ""
|
||||
|
||||
# According to the ChatML documentation, messages are formatted as:
|
||||
# <|im_start|>role_name
|
||||
# content
|
||||
# <|im_end|>
|
||||
# There should be a newline after role_name and before <|im_end|>.
|
||||
# And a newline after <|im_end|> to separate messages.
|
||||
|
||||
# Skip empty messages if content is crucial.
|
||||
# For ChatML, even an empty content string is typically included.
|
||||
# if not content and role not in ("assistant"): # Allow assistant to have empty content for prompting
|
||||
# logger.debug(f"ChatMLPromptStyle: Skipping empty message from {role}")
|
||||
# continue
|
||||
|
||||
prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
|
||||
|
||||
# Add the final prompt for the assistant to speak
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
return prompt
|
||||
|
||||
|
592
tests/components/llm/test_prompt_helper.py
Normal file
592
tests/components/llm/test_prompt_helper.py
Normal file
@ -0,0 +1,592 @@
|
||||
import pytest
|
||||
from llama_index.core.llms import ChatMessage, MessageRole
|
||||
|
||||
from private_gpt.components.llm.prompt_helper import (
|
||||
Llama3PromptStyle,
|
||||
MistralPromptStyle,
|
||||
ChatMLPromptStyle,
|
||||
TagPromptStyle,
|
||||
AbstractPromptStyle, # For type hinting if needed
|
||||
)
|
||||
|
||||
# Helper function to create ChatMessage objects easily
|
||||
def _message(role: MessageRole, content: str, **additional_kwargs) -> ChatMessage:
|
||||
return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs)
|
||||
|
||||
# Expected outputs will be defined within each test or test class
|
||||
|
||||
class TestLlama3PromptStyle:
|
||||
BOS = "<|begin_of_text|>"
|
||||
EOT = "<|eot_id|>"
|
||||
B_SYS_HEADER = "<|start_header_id|>system<|end_header_id|>"
|
||||
B_USER_HEADER = "<|start_header_id|>user<|end_header_id|>"
|
||||
B_ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
|
||||
B_TOOL_CODE_HEADER = "<|start_header_id|>tool_code<|end_header_id|>"
|
||||
B_TOOL_OUTPUT_HEADER = "<|start_header_id|>tool_output<|end_header_id|>"
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible and follow ALL given instructions. "
|
||||
"Do not speculate or make up information. "
|
||||
"Do not reference any given instructions or context. "
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def style(self) -> Llama3PromptStyle:
|
||||
return Llama3PromptStyle()
|
||||
|
||||
def test_empty_messages(self, style: Llama3PromptStyle) -> None:
|
||||
messages = []
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_simple_user_assistant_chat(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.ASSISTANT, "Hi there!"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nHello{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\nHi there!{self.EOT}"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_with_system_message(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.SYSTEM, "You are a test bot."),
|
||||
_message(MessageRole.USER, "Ping"),
|
||||
_message(MessageRole.ASSISTANT, "Pong"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\nYou are a test bot.{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nPing{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\nPong{self.EOT}"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_completion_to_prompt(self, style: Llama3PromptStyle) -> None:
|
||||
completion = "Test completion"
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nTest completion{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n"
|
||||
)
|
||||
assert style._completion_to_prompt(completion) == expected
|
||||
|
||||
def test_tool_call_and_result(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "What's the weather in Paris?"),
|
||||
_message(
|
||||
MessageRole.ASSISTANT,
|
||||
content=None, # LlamaIndex might put tool call details here, or just use additional_kwargs
|
||||
additional_kwargs={"type": "tool_call", "tool_call_id": "123", "name": "get_weather", "arguments": '{"location": "Paris"}'}
|
||||
),
|
||||
_message(
|
||||
MessageRole.TOOL,
|
||||
content='{"temperature": "20C"}',
|
||||
additional_kwargs={"type": "tool_result", "tool_call_id": "123", "name": "get_weather"}
|
||||
),
|
||||
]
|
||||
# Note: The current Llama3PromptStyle implementation uses message.content for tool call/result content.
|
||||
# If additional_kwargs are used to structure tool calls (like OpenAI), the style needs to be adapted.
|
||||
# For this test, we assume content holds the direct string for tool_code and tool_output.
|
||||
# Let's adjust the messages based on current implementation that uses .content for tool_code/output
|
||||
messages_for_current_impl = [
|
||||
_message(MessageRole.USER, "What's the weather in Paris?"),
|
||||
_message(
|
||||
MessageRole.ASSISTANT,
|
||||
content='get_weather({"location": "Paris"})', # Simplified tool call content
|
||||
additional_kwargs={"type": "tool_call"}
|
||||
),
|
||||
_message(
|
||||
MessageRole.TOOL,
|
||||
content='{"temperature": "20C"}',
|
||||
additional_kwargs={"type": "tool_result"} # No specific tool_call_id or name used by current style from additional_kwargs
|
||||
),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nWhat's the weather in Paris?{self.EOT}"
|
||||
f"{self.B_TOOL_CODE_HEADER}\n\nget_weather({{\"location\": \"Paris\"}}){self.EOT}" # Content is 'get_weather(...)'
|
||||
f"{self.B_TOOL_OUTPUT_HEADER}\n\n{{\"temperature\": \"20C\"}}{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n" # Assistant should respond after tool result
|
||||
)
|
||||
assert style._messages_to_prompt(messages_for_current_impl) == expected
|
||||
|
||||
def test_multiple_interactions_with_tools(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Can you search for prompt engineering techniques?"),
|
||||
_message(MessageRole.ASSISTANT, content="Okay, I will search for that.", additional_kwargs={}), # Normal assistant message
|
||||
_message(MessageRole.ASSISTANT, content='search_web({"query": "prompt engineering techniques"})', additional_kwargs={"type": "tool_call"}),
|
||||
_message(MessageRole.TOOL, content='[Result 1: ...]', additional_kwargs={"type": "tool_result"}),
|
||||
_message(MessageRole.ASSISTANT, content="I found one result. Should I look for more?", additional_kwargs={}),
|
||||
_message(MessageRole.USER, "Yes, please find another one."),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nCan you search for prompt engineering techniques?{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\nOkay, I will search for that.{self.EOT}"
|
||||
f"{self.B_TOOL_CODE_HEADER}\n\nsearch_web({{\"query\": \"prompt engineering techniques\"}}){self.EOT}"
|
||||
f"{self.B_TOOL_OUTPUT_HEADER}\n\n[Result 1: ...]{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\nI found one result. Should I look for more?{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nYes, please find another one.{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_ending_with_user_message(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "First message"),
|
||||
_message(MessageRole.ASSISTANT, "First response"),
|
||||
_message(MessageRole.USER, "Second message, expecting response"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nFirst message{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\nFirst response{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nSecond message, expecting response{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_ending_with_tool_result(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Find info on X."),
|
||||
_message(MessageRole.ASSISTANT, content='search({"topic": "X"})', additional_kwargs={"type": "tool_call"}),
|
||||
_message(MessageRole.TOOL, content="Info about X found.", additional_kwargs={"type": "tool_result"}),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\nFind info on X.{self.EOT}"
|
||||
f"{self.B_TOOL_CODE_HEADER}\n\nsearch({{\"topic\": \"X\"}}){self.EOT}"
|
||||
f"{self.B_TOOL_OUTPUT_HEADER}\n\nInfo about X found.{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_message_with_empty_content(self, style: Llama3PromptStyle) -> None:
|
||||
# Llama3PromptStyle skips messages with None content, but not necessarily empty string.
|
||||
# Let's test with an empty string for user, and None for assistant (which should be skipped)
|
||||
messages = [
|
||||
_message(MessageRole.USER, ""), # Empty string content
|
||||
_message(MessageRole.ASSISTANT, None), # None content, should be skipped
|
||||
_message(MessageRole.USER, "Follow up")
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}{self.B_SYS_HEADER}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.EOT}"
|
||||
f"{self.B_USER_HEADER}\n\n{self.EOT}" # Empty content for user
|
||||
f"{self.B_USER_HEADER}\n\nFollow up{self.EOT}" # Assistant message was skipped
|
||||
f"{self.B_ASSISTANT_HEADER}\n\n"
|
||||
)
|
||||
# The style's loop: `if not message or message.content is None: continue`
|
||||
# An empty string `""` is not `None`, so it should be included.
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_system_message_not_first(self, style: Llama3PromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.SYSTEM, "System message in the middle (unusual)."),
|
||||
_message(MessageRole.ASSISTANT, "Hi there!"),
|
||||
]
|
||||
# The current implementation processes system messages whenever they appear.
|
||||
# If a system message appears, it sets `has_system_message = True`.
|
||||
# If NO system message appears, a default one is prepended.
|
||||
# If one DOES appear, it's used, and default is not prepended.
|
||||
expected = (
|
||||
f"{self.BOS}"
|
||||
# Default system prompt is NOT added because a system message IS present.
|
||||
f"{self.B_USER_HEADER}\n\nHello{self.EOT}"
|
||||
f"{self.B_SYS_HEADER}\n\nSystem message in the middle (unusual).{self.EOT}"
|
||||
f"{self.B_ASSISTANT_HEADER}\n\nHi there!{self.EOT}"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
|
||||
class TestMistralPromptStyle:
|
||||
@pytest.fixture
|
||||
def style(self) -> MistralPromptStyle:
|
||||
return MistralPromptStyle()
|
||||
|
||||
def test_empty_messages(self, style: MistralPromptStyle) -> None:
|
||||
messages = []
|
||||
# The refactored version should produce an empty string if no instructions are pending.
|
||||
# Or, if it were to prompt for something, it might be "<s>[INST] [/INST]" or just ""
|
||||
# Based on current refactored code: if current_instruction_parts is empty, it returns prompt ("")
|
||||
assert style._messages_to_prompt(messages) == ""
|
||||
|
||||
def test_simple_user_assistant_chat(self, style: MistralPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.ASSISTANT, "Hi there!"),
|
||||
]
|
||||
expected = "<s>[INST] Hello [/INST] Hi there!</s>"
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_with_system_message(self, style: MistralPromptStyle) -> None:
|
||||
# System messages are treated like user messages in the current Mistral impl
|
||||
messages = [
|
||||
_message(MessageRole.SYSTEM, "You are helpful."),
|
||||
_message(MessageRole.USER, "Ping"),
|
||||
_message(MessageRole.ASSISTANT, "Pong"),
|
||||
]
|
||||
expected = "<s>[INST] You are helpful.\nPing [/INST] Pong</s>"
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_completion_to_prompt(self, style: MistralPromptStyle) -> None:
|
||||
completion = "Test completion"
|
||||
# This will call _messages_to_prompt with [ChatMessage(role=USER, content="Test completion")]
|
||||
expected = "<s>[INST] Test completion [/INST]"
|
||||
assert style._completion_to_prompt(completion) == expected
|
||||
|
||||
def test_consecutive_user_messages(self, style: MistralPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "First part."),
|
||||
_message(MessageRole.USER, "Second part."),
|
||||
_message(MessageRole.ASSISTANT, "Understood."),
|
||||
]
|
||||
expected = "<s>[INST] First part.\nSecond part. [/INST] Understood.</s>"
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_ending_with_user_message(self, style: MistralPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hi"),
|
||||
_message(MessageRole.ASSISTANT, "Hello"),
|
||||
_message(MessageRole.USER, "How are you?"),
|
||||
]
|
||||
# Note: The previous prompt had "<s>[INST] Hi [/INST] Hello</s>"
|
||||
# The new user message should start a new <s>[INST] block if prompt was not empty.
|
||||
# Current logic: bos_token = "<s>" if not prompt else ""
|
||||
# Since prompt is not empty after "Hello</s>", bos_token will be "".
|
||||
# This might be an issue with the current Mistral refactor if strict BOS per turn is needed.
|
||||
# The current refactored code for Mistral:
|
||||
# prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]"
|
||||
# If prompt = "<s>[INST] Hi [/INST] Hello</s>", then bos_token is "", so it becomes:
|
||||
# "<s>[INST] Hi [/INST] Hello</s>[INST] How are you? [/INST]" -> This seems correct for continued conversation.
|
||||
expected = "<s>[INST] Hi [/INST] Hello</s>[INST] How are you? [/INST]"
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_initial_assistant_message_skipped(self, style: MistralPromptStyle, caplog) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.ASSISTANT, "I speak first!"),
|
||||
_message(MessageRole.USER, "Oh, hello there."),
|
||||
]
|
||||
# The first assistant message should be skipped with a warning.
|
||||
# The prompt should then start with the user message.
|
||||
expected = "<s>[INST] Oh, hello there. [/INST]"
|
||||
with caplog.at_level("WARNING"):
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
assert "MistralPromptStyle: First message is from assistant, skipping." in caplog.text
|
||||
|
||||
def test_multiple_assistant_messages_in_a_row(self, style: MistralPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "User message"),
|
||||
_message(MessageRole.ASSISTANT, "Assistant first response."),
|
||||
_message(MessageRole.ASSISTANT, "Assistant second response (after no user message)."),
|
||||
]
|
||||
# current_instruction_parts will be empty when processing the second assistant message.
|
||||
# The logic is:
|
||||
# if current_instruction_parts: prompt += bos_token + "[INST] " + "\n".join(current_instruction_parts) + " [/INST]"
|
||||
# prompt += " " + content + "</s>"
|
||||
# So, it will correctly append the second assistant message without a new [INST]
|
||||
expected = ("<s>[INST] User message [/INST] Assistant first response.</s>"
|
||||
" Assistant second response (after no user message).</s>")
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_system_user_assistant_alternating(self, style: MistralPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.SYSTEM, "System setup."),
|
||||
_message(MessageRole.USER, "User query 1."),
|
||||
_message(MessageRole.ASSISTANT, "Assistant answer 1."),
|
||||
_message(MessageRole.USER, "User query 2."), # System messages are part of INST with user
|
||||
_message(MessageRole.ASSISTANT, "Assistant answer 2."),
|
||||
]
|
||||
expected = ("<s>[INST] System setup.\nUser query 1. [/INST] Assistant answer 1.</s>"
|
||||
"[INST] User query 2. [/INST] Assistant answer 2.</s>")
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_empty_content_messages(self, style: MistralPromptStyle, caplog) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.USER, None), # Skipped by `if not content and message.role != MessageRole.ASSISTANT:`
|
||||
_message(MessageRole.USER, ""), # Skipped by the same logic
|
||||
_message(MessageRole.ASSISTANT, "Hi"),
|
||||
_message(MessageRole.ASSISTANT, ""), # Empty assistant message, kept
|
||||
_message(MessageRole.ASSISTANT, None),# Empty assistant message, kept (content becomes "")
|
||||
]
|
||||
# The refactored code skips empty non-assistant messages.
|
||||
# Empty assistant messages (content="" or content=None) are kept.
|
||||
expected = ("<s>[INST] Hello [/INST] Hi</s>"
|
||||
" </s>" # From assistant with content=""
|
||||
" </s>") # From assistant with content=None (becomes "")
|
||||
|
||||
with caplog.at_level("DEBUG"): # The skipping messages are logged at DEBUG level
|
||||
actual = style._messages_to_prompt(messages)
|
||||
assert actual == expected
|
||||
# Check that specific debug messages for skipping are present
|
||||
assert "Skipping empty non-assistant message." in caplog.text # For the None and "" user messages
|
||||
|
||||
|
||||
class TestChatMLPromptStyle:
|
||||
IM_START = "<|im_start|>"
|
||||
IM_END = "<|im_end|>"
|
||||
|
||||
@pytest.fixture
|
||||
def style(self) -> ChatMLPromptStyle:
|
||||
return ChatMLPromptStyle()
|
||||
|
||||
def test_empty_messages(self, style: ChatMLPromptStyle) -> None:
|
||||
messages = []
|
||||
# Expected: just the final assistant prompt
|
||||
expected = f"{self.IM_START}assistant\n"
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_simple_user_assistant_chat(self, style: ChatMLPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.ASSISTANT, "Hi there!"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.IM_START}user\nHello{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\nHi there!{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_with_system_message(self, style: ChatMLPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.SYSTEM, "You are ChatML bot."),
|
||||
_message(MessageRole.USER, "Ping"),
|
||||
_message(MessageRole.ASSISTANT, "Pong"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.IM_START}system\nYou are ChatML bot.{self.IM_END}\n"
|
||||
f"{self.IM_START}user\nPing{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\nPong{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_completion_to_prompt(self, style: ChatMLPromptStyle) -> None:
|
||||
completion = "Test user input"
|
||||
# This will call _messages_to_prompt with [ChatMessage(role=USER, content="Test user input")]
|
||||
expected = (
|
||||
f"{self.IM_START}user\nTest user input{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
assert style._completion_to_prompt(completion) == expected
|
||||
|
||||
def test_multiple_turns(self, style: ChatMLPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "First user message."),
|
||||
_message(MessageRole.ASSISTANT, "First assistant response."),
|
||||
_message(MessageRole.USER, "Second user message."),
|
||||
_message(MessageRole.ASSISTANT, "Second assistant response."),
|
||||
]
|
||||
expected = (
|
||||
f"{self.IM_START}user\nFirst user message.{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\nFirst assistant response.{self.IM_END}\n"
|
||||
f"{self.IM_START}user\nSecond user message.{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\nSecond assistant response.{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_message_with_empty_content(self, style: ChatMLPromptStyle) -> None:
|
||||
# ChatML typically includes messages even with empty content.
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.ASSISTANT, ""), # Empty string content
|
||||
_message(MessageRole.USER, "Follow up")
|
||||
]
|
||||
expected = (
|
||||
f"{self.IM_START}user\nHello{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n{self.IM_END}\n" # Empty content for assistant
|
||||
f"{self.IM_START}user\nFollow up{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_message_with_none_content(self, style: ChatMLPromptStyle) -> None:
|
||||
# ChatML typically includes messages even with empty content (None becomes empty string).
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.ASSISTANT, None), # None content
|
||||
_message(MessageRole.USER, "Follow up")
|
||||
]
|
||||
expected = (
|
||||
f"{self.IM_START}user\nHello{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n{self.IM_END}\n" # Empty content for assistant
|
||||
f"{self.IM_START}user\nFollow up{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_correct_token_usage_and_newlines(self, style: ChatMLPromptStyle) -> None:
|
||||
# Validates: <|im_start|>role\ncontent<|im_end|>\n ... <|im_start|>assistant\n
|
||||
messages = [_message(MessageRole.USER, "Test")]
|
||||
expected = (
|
||||
f"{self.IM_START}user\nTest{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
actual = style._messages_to_prompt(messages)
|
||||
assert actual == expected
|
||||
assert actual.count(self.IM_START) == 2
|
||||
assert actual.count(self.IM_END) == 1
|
||||
assert actual.endswith(f"{self.IM_START}assistant\n")
|
||||
# Check newlines: after role, after content (before im_end), after im_end
|
||||
# <|im_start|>user\nTest<|im_end|>\n<|im_start|>assistant\n
|
||||
# Role is followed by \n. Content is on its own line implicitly. im_end is followed by \n.
|
||||
# The structure f"{IM_START}{role}\n{content}{IM_END}\n" ensures this.
|
||||
user_part = f"{self.IM_START}user\nTest{self.IM_END}\n"
|
||||
assert user_part in actual
|
||||
|
||||
messages_with_system = [
|
||||
_message(MessageRole.SYSTEM, "Sys"),
|
||||
_message(MessageRole.USER, "Usr")
|
||||
]
|
||||
expected_sys_usr = (
|
||||
f"{self.IM_START}system\nSys{self.IM_END}\n"
|
||||
f"{self.IM_START}user\nUsr{self.IM_END}\n"
|
||||
f"{self.IM_START}assistant\n"
|
||||
)
|
||||
actual_sys_usr = style._messages_to_prompt(messages_with_system)
|
||||
assert actual_sys_usr == expected_sys_usr
|
||||
assert actual_sys_usr.count(self.IM_START) == 3
|
||||
assert actual_sys_usr.count(self.IM_END) == 2
|
||||
|
||||
|
||||
class TestTagPromptStyle:
|
||||
BOS = "<s>"
|
||||
EOS = "</s>"
|
||||
|
||||
@pytest.fixture
|
||||
def style(self) -> TagPromptStyle:
|
||||
return TagPromptStyle()
|
||||
|
||||
def test_empty_messages(self, style: TagPromptStyle) -> None:
|
||||
messages = []
|
||||
# Expected based on current TagPromptStyle: "<s><|assistant|>: "
|
||||
expected = f"{self.BOS}<|assistant|>: "
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_simple_user_assistant_chat(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Hello"),
|
||||
_message(MessageRole.ASSISTANT, "Hi there!"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|user|>: Hello\n"
|
||||
f"<|assistant|>: Hi there!{self.EOS}\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_with_system_message(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.SYSTEM, "System instructions."),
|
||||
_message(MessageRole.USER, "Ping"),
|
||||
_message(MessageRole.ASSISTANT, "Pong"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|system|>: System instructions.\n"
|
||||
f"<|user|>: Ping\n"
|
||||
f"<|assistant|>: Pong{self.EOS}\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_completion_to_prompt(self, style: TagPromptStyle) -> None:
|
||||
completion = "Test user input"
|
||||
# Expected: <s><|user|>: Test user input\n<|assistant|>:
|
||||
expected = f"{self.BOS}<|user|>: Test user input\n<|assistant|>: "
|
||||
assert style._completion_to_prompt(completion) == expected
|
||||
|
||||
def test_bos_eos_placement_multiple_turns(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "User1"),
|
||||
_message(MessageRole.ASSISTANT, "Assistant1"),
|
||||
_message(MessageRole.USER, "User2"),
|
||||
_message(MessageRole.ASSISTANT, "Assistant2"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|user|>: User1\n"
|
||||
f"<|assistant|>: Assistant1{self.EOS}\n"
|
||||
f"<|user|>: User2\n"
|
||||
f"<|assistant|>: Assistant2{self.EOS}\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_ending_with_user_message(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "User1"),
|
||||
_message(MessageRole.ASSISTANT, "Assistant1"),
|
||||
_message(MessageRole.USER, "User2 (prompting for response)"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|user|>: User1\n"
|
||||
f"<|assistant|>: Assistant1{self.EOS}\n"
|
||||
f"<|user|>: User2 (prompting for response)\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_message_with_empty_content(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, ""),
|
||||
_message(MessageRole.ASSISTANT, ""), # Empty assistant response
|
||||
]
|
||||
# Content is stripped, so empty string remains empty.
|
||||
expected = (
|
||||
f"{self.BOS}<|user|>: \n"
|
||||
f"<|assistant|>: {self.EOS}\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_message_with_none_content(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, None), # Becomes empty string
|
||||
_message(MessageRole.ASSISTANT, None), # Becomes empty string
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|user|>: \n"
|
||||
f"<|assistant|>: {self.EOS}\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_only_user_message(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.USER, "Just a user message"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|user|>: Just a user message\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_only_assistant_message(self, style: TagPromptStyle) -> None:
|
||||
# This is an unusual case, but the style should handle it.
|
||||
messages = [
|
||||
_message(MessageRole.ASSISTANT, "Only assistant"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|assistant|>: Only assistant{self.EOS}\n"
|
||||
f"<|assistant|>: " # Still prompts for assistant
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
||||
|
||||
def test_only_system_message(self, style: TagPromptStyle) -> None:
|
||||
messages = [
|
||||
_message(MessageRole.SYSTEM, "Only system"),
|
||||
]
|
||||
expected = (
|
||||
f"{self.BOS}<|system|>: Only system\n"
|
||||
f"<|assistant|>: "
|
||||
)
|
||||
assert style._messages_to_prompt(messages) == expected
|
Loading…
Reference in New Issue
Block a user