feat: add mistral + chatml prompts (#1426)
@@ -2,8 +2,10 @@ import pytest
 from llama_index.llms import ChatMessage, MessageRole
 
 from private_gpt.components.llm.prompt_helper import (
+    ChatMLPromptStyle,
     DefaultPromptStyle,
     Llama2PromptStyle,
+    MistralPromptStyle,
     TagPromptStyle,
     get_prompt_style,
 )
@@ -15,6 +17,8 @@ from private_gpt.components.llm.prompt_helper import (
         ("default", DefaultPromptStyle),
         ("llama2", Llama2PromptStyle),
         ("tag", TagPromptStyle),
+        ("mistral", MistralPromptStyle),
+        ("chatml", ChatMLPromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -62,6 +66,39 @@ def test_tag_prompt_style_format_with_system_prompt():
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
 
 
+def test_mistral_prompt_style_format():
+    prompt_style = MistralPromptStyle()
+    messages = [
+        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<s>[INST] You are an AI assistant. [/INST]</s>"
+        "[INST] Hello, how are you doing? [/INST]"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_chatml_prompt_style_format():
+    prompt_style = ChatMLPromptStyle()
+    messages = [
+        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|im_start|>system\n"
+        "You are an AI assistant.<|im_end|>\n"
+        "<|im_start|>user\n"
+        "Hello, how are you doing?<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
 def test_llama2_prompt_style_format():
     prompt_style = Llama2PromptStyle()
     messages = [
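For readers who want to see the two formats outside the test file, here is a minimal standalone sketch that reproduces the exact strings the new tests assert. Everything in it is illustrative: the Message dataclass is a stand-in for llama_index's ChatMessage, and mistral_messages_to_prompt / chatml_messages_to_prompt are hypothetical names, not the API of private_gpt.components.llm.prompt_helper, whose real prompt-style classes may be implemented differently.

# Illustrative sketch only: standalone functions that produce the prompt
# strings asserted by the tests above. Message is a hypothetical stand-in
# for llama_index.llms.ChatMessage.
from dataclasses import dataclass


@dataclass
class Message:
    role: str  # "system", "user", or "assistant"
    content: str


def mistral_messages_to_prompt(messages: list[Message]) -> str:
    # Mistral instruct format: every turn is wrapped in [INST] ... [/INST];
    # the first turn is additionally wrapped in <s>...</s>.
    parts = []
    for i, m in enumerate(messages):
        inst = f"[INST] {m.content} [/INST]"
        parts.append(f"<s>{inst}</s>" if i == 0 else inst)
    return "".join(parts)


def chatml_messages_to_prompt(messages: list[Message]) -> str:
    # ChatML format: each turn is delimited by <|im_start|>{role}\n ...
    # <|im_end|>\n, and the prompt ends with an opened assistant turn
    # for the model to complete.
    out = ""
    for m in messages:
        out += f"<|im_start|>{m.role}\n{m.content}<|im_end|>\n"
    return out + "<|im_start|>assistant\n"


if __name__ == "__main__":
    msgs = [
        Message("system", "You are an AI assistant."),
        Message("user", "Hello, how are you doing?"),
    ]
    # Both outputs match the expected_prompt strings in the tests above.
    assert mistral_messages_to_prompt(msgs) == (
        "<s>[INST] You are an AI assistant. [/INST]</s>"
        "[INST] Hello, how are you doing? [/INST]"
    )
    assert chatml_messages_to_prompt(msgs) == (
        "<|im_start|>system\n"
        "You are an AI assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "Hello, how are you doing?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    print("both sketches match the expected prompts")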
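Two details of the expected strings are worth noting. The Mistral format ends at the user's closing [/INST], leaving the model to answer immediately after the tag, while ChatML ends with an opened <|im_start|>assistant line that explicitly cues where the completion begins. And per the extended parametrized test, callers select these styles by name through the factory, e.g. get_prompt_style("mistral") yielding a MistralPromptStyle instance.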