mirror of https://github.com/imartinez/privateGPT.git (synced 2025-06-22 05:30:34 +00:00)
feat: add mistral + chatml prompts (#1426)
This commit is contained in:
  parent 6191bcdbd6
  commit e326126d0d
@@ -24,7 +24,7 @@ user: {{ user_message }}
 assistant: {{ assistant_message }}
 ```
 
-And the "`tag`" style looks like this:
+The "`tag`" style looks like this:
 
 ```text
 <|system|>: {{ system_prompt }}
@@ -32,7 +32,23 @@ And the "`tag`" style looks like this:
 <|assistant|>: {{ assistant_message }}
 ```
 
-Some LLMs will not understand this prompt style, and will not work (returning nothing).
+The "`mistral`" style looks like this:
+
+```text
+<s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
+```
+
+The "`chatml`" style looks like this:
+
+```text
+<|im_start|>system
+{{ system_prompt }}<|im_end|>
+<|im_start|>user
+{{ user_message }}<|im_end|>
+<|im_start|>assistant
+{{ assistant_message }}
+```
+
+Some LLMs will not understand these prompt styles, and will not work (returning nothing).
 You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
 change the way the messages are formatted to be passed to the LLM.
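For orientation, here is a minimal standalone sketch (not part of the diff) that assembles the same two-turn conversation in both new styles by hand, mirroring the templates above:

```python
# Hand-assembled prompt strings mirroring the documented templates;
# independent of privateGPT's prompt-style classes.
system_prompt = "You are an AI assistant."
user_message = "Hello, how are you doing?"

# "mistral": both turns are wrapped in [INST] ... [/INST]; </s> closes the
# system turn before the user turn begins.
mistral_prompt = f"<s>[INST] {system_prompt} [/INST]</s>[INST] {user_message} [/INST]"

# "chatml": each turn sits between <|im_start|>{role} and <|im_end|>, and the
# prompt ends with an open assistant turn for the model to complete.
chatml_prompt = (
    f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
    f"<|im_start|>user\n{user_message}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

print(mistral_prompt)
print(chatml_prompt)
```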
@@ -123,8 +123,51 @@ class TagPromptStyle(AbstractPromptStyle):
         )
 
 
+class MistralPromptStyle(AbstractPromptStyle):
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        prompt = "<s>"
+        for message in messages:
+            role = message.role
+            content = message.content or ""
+            if role.lower() == "system":
+                message_from_user = f"[INST] {content.strip()} [/INST]"
+                prompt += message_from_user
+            elif role.lower() == "user":
+                prompt += "</s>"
+                message_from_user = f"[INST] {content.strip()} [/INST]"
+                prompt += message_from_user
+        return prompt
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return self._messages_to_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
+
+
+class ChatMLPromptStyle(AbstractPromptStyle):
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        prompt = "<|im_start|>system\n"
+        for message in messages:
+            role = message.role
+            content = message.content or ""
+            if role.lower() == "system":
+                message_from_user = f"{content.strip()}"
+                prompt += message_from_user
+            elif role.lower() == "user":
+                prompt += "<|im_end|>\n<|im_start|>user\n"
+                message_from_user = f"{content.strip()}<|im_end|>\n"
+                prompt += message_from_user
+        prompt += "<|im_start|>assistant\n"
+        return prompt
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return self._messages_to_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
+
+
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag"] | None
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
@@ -137,4 +180,8 @@ def get_prompt_style(
         return Llama2PromptStyle()
     elif prompt_style == "tag":
         return TagPromptStyle()
+    elif prompt_style == "mistral":
+        return MistralPromptStyle()
+    elif prompt_style == "chatml":
+        return ChatMLPromptStyle()
     raise ValueError(f"Unknown prompt_style='{prompt_style}'")
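A short usage sketch of the dispatch above, assuming a working privateGPT checkout so the imports resolve (the `ChatMessage`/`MessageRole` imports are the same ones the tests below use):

```python
# Usage sketch; assumes privateGPT and llama_index are importable.
from llama_index.llms import ChatMessage, MessageRole
from private_gpt.components.llm.prompt_helper import get_prompt_style

style = get_prompt_style("mistral")  # returns a MistralPromptStyle instance
prompt = style.messages_to_prompt(
    [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
)
print(prompt)
# -> <s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]

# Any other string raises:
# get_prompt_style("vicuna")  # ValueError: Unknown prompt_style='vicuna'
```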
@@ -110,13 +110,14 @@ class LocalSettings(BaseModel):
     embedding_hf_model_name: str = Field(
         description="Name of the HuggingFace model to use for embeddings"
     )
-    prompt_style: Literal["default", "llama2", "tag"] = Field(
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
         "llama2",
         description=(
             "The prompt style to use for the chat engine. "
             "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
             "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
             "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
+            "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] {User Instructions} [/INST]`.\n"
             "`llama2` is the historic behaviour. `default` might work better with your custom models."
         ),
     )
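The `Literal` annotation is what actually restricts the accepted values. A trimmed sketch of how pydantic enforces it (the model name here is a hypothetical stand-in, not privateGPT's real class):

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class PromptSettingsSketch(BaseModel):
    # Hypothetical, trimmed stand-in for LocalSettings.prompt_style.
    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field("llama2")


PromptSettingsSketch(prompt_style="chatml")  # accepted

try:
    PromptSettingsSketch(prompt_style="vicuna")  # not in the Literal: rejected
except ValidationError as err:
    print(err)
```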
@@ -1,6 +1,7 @@
 import argparse
 import json
 import sys
 
 import yaml
 from uvicorn.importer import import_from_string
@@ -51,7 +51,7 @@ qdrant:
   path: local_data/private_gpt/qdrant
 
 local:
-  prompt_style: "llama2"
+  prompt_style: "mistral"
   llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
  llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
   embedding_hf_model_name: BAAI/bge-small-en-v1.5
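A quick way to confirm which style is active, assuming a settings file shaped like the hunk above (the `settings.yaml` path is an assumption):

```python
# Sketch: read the settings file and print the configured prompt style.
import yaml

with open("settings.yaml") as f:
    settings = yaml.safe_load(f)

print(settings["local"]["prompt_style"])  # "mistral" after this commit
```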
@@ -2,8 +2,10 @@ import pytest
 from llama_index.llms import ChatMessage, MessageRole
 
 from private_gpt.components.llm.prompt_helper import (
+    ChatMLPromptStyle,
     DefaultPromptStyle,
     Llama2PromptStyle,
+    MistralPromptStyle,
     TagPromptStyle,
     get_prompt_style,
 )
@@ -15,6 +17,8 @@ from private_gpt.components.llm.prompt_helper import (
         ("default", DefaultPromptStyle),
         ("llama2", Llama2PromptStyle),
         ("tag", TagPromptStyle),
+        ("mistral", MistralPromptStyle),
+        ("chatml", ChatMLPromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -62,6 +66,39 @@ def test_tag_prompt_style_format_with_system_prompt():
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
 
 
+def test_mistral_prompt_style_format():
+    prompt_style = MistralPromptStyle()
+    messages = [
+        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<s>[INST] You are an AI assistant. [/INST]</s>"
+        "[INST] Hello, how are you doing? [/INST]"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_chatml_prompt_style_format():
+    prompt_style = ChatMLPromptStyle()
+    messages = [
+        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|im_start|>system\n"
+        "You are an AI assistant.<|im_end|>\n"
+        "<|im_start|>user\n"
+        "Hello, how are you doing?<|im_end|>\n"
+        "<|im_start|>assistant\n"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
 def test_llama2_prompt_style_format():
     prompt_style = Llama2PromptStyle()
     messages = [
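To run just the two new tests, something like this works (the test-file path is assumed from the repository layout):

```python
# Select only the new prompt-style tests; equivalent to `pytest -k` on the CLI.
import pytest

raise SystemExit(
    pytest.main(["tests/test_prompt_helper.py", "-k", "mistral or chatml"])
)
```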