Mirror of https://github.com/imartinez/privateGPT.git, synced 2025-09-14 14:19:08 +00:00.
added llama3 prompt (#1962)

* added llama3 prompt
* more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14
* fix: new llama3 prompt

Co-authored-by: Javier Martinez <javiermartinezalvarez98@gmail.com>
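The VectorStore -> BasePydanticVectorStore change mentioned above follows the linked llama-index changelog entry (2024-05-14), which moves consumers from the old VectorStore type to BasePydanticVectorStore. A minimal sketch of the migration, assuming a component that annotates its store; the class and field names below are illustrative, not taken from this diff:

    # Illustrative migration sketch; this class is not part of this commit.
    # Before: from llama_index.core.vector_stores.types import VectorStore
    from llama_index.core.vector_stores.types import BasePydanticVectorStore


    class VectorStoreComponent:
        vector_store: BasePydanticVectorStore  # previously annotated as VectorStore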
@@ -5,6 +5,7 @@ from private_gpt.components.llm.prompt_helper import (
     ChatMLPromptStyle,
     DefaultPromptStyle,
     Llama2PromptStyle,
+    Llama3PromptStyle,
     MistralPromptStyle,
     TagPromptStyle,
     get_prompt_style,
@@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt():
     )
 
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_format():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "Hello, how are you doing?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_with_default_system():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="Hello!", role=MessageRole.USER),
+    ]
+    expected = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert prompt_style._messages_to_prompt(messages) == expected
+
+
+def test_llama3_prompt_style_with_assistant_response():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="What is the capital of France?", role=MessageRole.USER),
+        ChatMessage(
+            content="The capital of France is Paris.", role=MessageRole.ASSISTANT
+        ),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "What is the capital of France?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        "The capital of France is Paris.<|eot_id|>"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
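The Llama3PromptStyle implementation itself lives in private_gpt/components/llm/prompt_helper.py and is not part of this diff. As a reading aid, here is a minimal sketch that would satisfy the three tests above; the DEFAULT_SYSTEM_PROMPT text, the _header helper, and the lack of base-class wiring are assumptions, not code from the commit:

    # Hypothetical sketch only -- the real Llama3PromptStyle may differ in detail.
    from collections.abc import Sequence

    from llama_index.core.llms import ChatMessage, MessageRole


    class Llama3PromptStyle:
        # The tests read this attribute; the exact default text is an assumption.
        DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

        BOS = "<|begin_of_text|>"
        EOT = "<|eot_id|>"

        def _header(self, role: str) -> str:
            # Llama 3 wraps each role name in header tokens, then a blank line.
            return f"<|start_header_id|>{role}<|end_header_id|>\n\n"

        def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
            prompt = self.BOS
            # Inject the default system prompt when the caller supplied none.
            if not messages or messages[0].role != MessageRole.SYSTEM:
                prompt += self._header("system") + self.DEFAULT_SYSTEM_PROMPT + self.EOT
            for msg in messages:
                prompt += self._header(msg.role.value) + (msg.content or "") + self.EOT
            # Leave an open assistant header only when the model still has to
            # answer, matching test_llama3_prompt_style_with_assistant_response.
            if not messages or messages[-1].role != MessageRole.ASSISTANT:
                prompt += self._header("assistant")
            return prompt

        def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
            return self._messages_to_prompt(messages)

Note that the second test calls the private _messages_to_prompt directly while the other two go through the public messages_to_prompt wrapper, so both paths must produce the same string.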