feat: add mistral + chatml prompts (#1426)

This commit is contained in:
CognitiveTech 2024-01-16 16:51:14 -05:00 committed by GitHub
parent 6191bcdbd6
commit e326126d0d
6 changed files with 107 additions and 5 deletions

View File

@ -24,7 +24,7 @@ user: {{ user_message }}
assistant: {{ assistant_message }}
```
The "`tag`" style looks like this:
```text
<|system|>: {{ system_prompt }}
@ -32,7 +32,23 @@ And the "`tag`" style looks like this:
<|assistant|>: {{ assistant_message }}
```
The "`mistral`" style looks like this:
```text
<s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
```
The "`chatml`" style looks like this:
```text
<|im_start|>system
{{ system_prompt }}<|im_end|>
<|im_start|>user"
{{ user_message }}<|im_end|>
<|im_start|>assistant
{{ assistant_message }}
```
Some LLMs will not understand these prompt styles, and will not work (returning nothing).
You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
change the way the messages are formatted to be passed to the LLM.
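For readers who want to see a style applied end to end, here is a minimal sketch using the prompt-style factory this commit adds; the expected output is taken verbatim from the new tests at the bottom of the diff:

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import get_prompt_style

style = get_prompt_style("mistral")
prompt = style.messages_to_prompt(
    [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
)
print(prompt)
# <s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
```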

View File

@ -123,8 +123,51 @@ class TagPromptStyle(AbstractPromptStyle):
)

class MistralPromptStyle(AbstractPromptStyle):
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = "<s>"
        for message in messages:
            role = message.role
            content = message.content or ""
            if role.lower() == "system":
                message_from_user = f"[INST] {content.strip()} [/INST]"
                prompt += message_from_user
            elif role.lower() == "user":
                prompt += "</s>"
                message_from_user = f"[INST] {content.strip()} [/INST]"
                prompt += message_from_user
        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
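A detail worth flagging in the loop above: `</s>` is emitted before every user turn, so a bare completion, which `_completion_to_prompt` wraps as a single user message, renders with an adjacent `<s></s>` pair. A small trace, assuming the public `completion_to_prompt` wrapper delegates to `_completion_to_prompt` the same way `messages_to_prompt` does in the tests:

```python
from private_gpt.components.llm.prompt_helper import MistralPromptStyle

style = MistralPromptStyle()
print(style.completion_to_prompt("How should I format a prompt?"))
# <s></s>[INST] How should I format a prompt? [/INST]
```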

class ChatMLPromptStyle(AbstractPromptStyle):
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = "<|im_start|>system\n"
        for message in messages:
            role = message.role
            content = message.content or ""
            if role.lower() == "system":
                message_from_user = f"{content.strip()}"
                prompt += message_from_user
            elif role.lower() == "user":
                prompt += "<|im_end|>\n<|im_start|>user\n"
                message_from_user = f"{content.strip()}<|im_end|>\n"
                prompt += message_from_user
        prompt += "<|im_start|>assistant\n"
        return prompt

    def _completion_to_prompt(self, completion: str) -> str:
        return self._messages_to_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
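Note that both new styles only branch on the `system` and `user` roles, so assistant content in the history is dropped, and the trailing `<|im_start|>assistant\n` line is what cues the model to generate. A hand trace of the completion path, under the same wrapper assumption as above:

```python
from private_gpt.components.llm.prompt_helper import ChatMLPromptStyle

style = ChatMLPromptStyle()
print(style.completion_to_prompt("Hello!"))
# <|im_start|>system
# <|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
```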

def get_prompt_style(
    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
) -> AbstractPromptStyle:
    """Get the prompt style to use from the given string.
@ -137,4 +180,8 @@ def get_prompt_style(
        return Llama2PromptStyle()
    elif prompt_style == "tag":
        return TagPromptStyle()
    elif prompt_style == "mistral":
        return MistralPromptStyle()
    elif prompt_style == "chatml":
        return ChatMLPromptStyle()
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")
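The factory now resolves the two new literals; a short sketch mirroring the parametrized test below, including the error path for an unrecognized value:

```python
from private_gpt.components.llm.prompt_helper import (
    ChatMLPromptStyle,
    MistralPromptStyle,
    get_prompt_style,
)

assert isinstance(get_prompt_style("mistral"), MistralPromptStyle)
assert isinstance(get_prompt_style("chatml"), ChatMLPromptStyle)

try:
    get_prompt_style("typo")
except ValueError as err:
    print(err)  # Unknown prompt_style='typo'
```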

View File

@ -110,13 +110,14 @@ class LocalSettings(BaseModel):
    embedding_hf_model_name: str = Field(
        description="Name of the HuggingFace model to use for embeddings"
    )
    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
        "llama2",
        description=(
            "The prompt style to use for the chat engine. "
            "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
            "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
"If `mistral` - use the `mistral prompt style. It shoudl look like <s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]"
"`llama2` is the historic behaviour. `default` might work better with your custom models." "`llama2` is the historic behaviour. `default` might work better with your custom models."
), ),
) )
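The call site that consumes this field is outside the diff; presumably the chat-engine component hands it straight to the factory. A hypothetical sketch (the `local_settings` parameter and `build_prompt_style` helper are assumptions for illustration, not part of this commit):

```python
from private_gpt.components.llm.prompt_helper import (
    AbstractPromptStyle,
    get_prompt_style,
)

def build_prompt_style(local_settings) -> AbstractPromptStyle:
    # Hypothetical wiring: `prompt_style` is constrained by the
    # Literal[...] field above, so only the five known values can
    # reach the factory here.
    return get_prompt_style(local_settings.prompt_style)
```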

View File

@ -1,6 +1,7 @@
import argparse
import json
import sys
import yaml
from uvicorn.importer import import_from_string

View File

@ -51,7 +51,7 @@ qdrant:
  path: local_data/private_gpt/qdrant

local:
  prompt_style: "mistral"
  llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
  llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
  embedding_hf_model_name: BAAI/bge-small-en-v1.5

View File

@ -2,8 +2,10 @@ import pytest
from llama_index.llms import ChatMessage, MessageRole
from private_gpt.components.llm.prompt_helper import (
    ChatMLPromptStyle,
    DefaultPromptStyle,
    Llama2PromptStyle,
    MistralPromptStyle,
    TagPromptStyle,
    get_prompt_style,
)
@ -15,6 +17,8 @@ from private_gpt.components.llm.prompt_helper import (
("default", DefaultPromptStyle), ("default", DefaultPromptStyle),
("llama2", Llama2PromptStyle), ("llama2", Llama2PromptStyle),
("tag", TagPromptStyle), ("tag", TagPromptStyle),
("mistral", MistralPromptStyle),
("chatml", ChatMLPromptStyle),
], ],
) )
def test_get_prompt_style_success(prompt_style, expected_prompt_style): def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@ -62,6 +66,39 @@ def test_tag_prompt_style_format_with_system_prompt():
    assert prompt_style.messages_to_prompt(messages) == expected_prompt

def test_mistral_prompt_style_format():
    prompt_style = MistralPromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected_prompt = (
        "<s>[INST] You are an AI assistant. [/INST]</s>"
        "[INST] Hello, how are you doing? [/INST]"
    )
    assert prompt_style.messages_to_prompt(messages) == expected_prompt


def test_chatml_prompt_style_format():
    prompt_style = ChatMLPromptStyle()
    messages = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected_prompt = (
        "<|im_start|>system\n"
        "You are an AI assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        "Hello, how are you doing?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    assert prompt_style.messages_to_prompt(messages) == expected_prompt

def test_llama2_prompt_style_format():
    prompt_style = Llama2PromptStyle()
    messages = [