diff --git a/fern/docs/pages/recipes/list-llm.mdx b/fern/docs/pages/recipes/list-llm.mdx
index 2cb80e48..1e53804b 100644
--- a/fern/docs/pages/recipes/list-llm.mdx
+++ b/fern/docs/pages/recipes/list-llm.mdx
@@ -24,7 +24,7 @@ user: {{ user_message }}
assistant: {{ assistant_message }}
```
-And the "`tag`" style looks like this:
+The "`tag`" style looks like this:
```text
<|system|>: {{ system_prompt }}
@@ -32,7 +32,23 @@ And the "`tag`" style looks like this:
<|assistant|>: {{ assistant_message }}
```
-Some LLMs will not understand this prompt style, and will not work (returning nothing).
+The "`mistral`" style looks like this:
+
+```text
+[INST] You are an AI assistant. [/INST][INST] Hello, how are you doing? [/INST]
+```
+
+The "`chatml`" style looks like this:
+```text
+<|im_start|>system
+{{ system_prompt }}<|im_end|>
+<|im_start|>user
+{{ user_message }}<|im_end|>
+<|im_start|>assistant
+{{ assistant_message }}
+```
+
+Some LLMs will not understand these prompt styles, and will not work (returning nothing).
You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
change the way the messages are formatted to be passed to the LLM.
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index d6a335f8..971cfa3b 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -42,7 +42,7 @@ class LLMComponent:
context_window=settings.llm.context_window,
generate_kwargs={},
# All to GPU
- model_kwargs={"n_gpu_layers": -1},
+ model_kwargs={"n_gpu_layers": -1, "offload_kqv": True},
# transform inputs into Llama2 format
messages_to_prompt=prompt_style.messages_to_prompt,
completion_to_prompt=prompt_style.completion_to_prompt,
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index a8ca60f2..d1df9b81 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -123,8 +123,51 @@ class TagPromptStyle(AbstractPromptStyle):
)
+class MistralPromptStyle(AbstractPromptStyle):
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ prompt = ""
+ for message in messages:
+ role = message.role
+ content = message.content or ""
+ if role.lower() == "system":
+ message_from_user = f"[INST] {content.strip()} [/INST]"
+ prompt += message_from_user
+ elif role.lower() == "user":
+ prompt += ""
+ message_from_user = f"[INST] {content.strip()} [/INST]"
+ prompt += message_from_user
+ return prompt
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ return self._messages_to_prompt(
+ [ChatMessage(content=completion, role=MessageRole.USER)]
+ )
+
+
+class ChatMLPromptStyle(AbstractPromptStyle):
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ prompt = "<|im_start|>system\n"
+ for message in messages:
+ role = message.role
+ content = message.content or ""
+ if role.lower() == "system":
+ message_from_user = f"{content.strip()}"
+ prompt += message_from_user
+ elif role.lower() == "user":
+ prompt += "<|im_end|>\n<|im_start|>user\n"
+ message_from_user = f"{content.strip()}<|im_end|>\n"
+ prompt += message_from_user
+ prompt += "<|im_start|>assistant\n"
+ return prompt
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ return self._messages_to_prompt(
+ [ChatMessage(content=completion, role=MessageRole.USER)]
+ )
+
+
def get_prompt_style(
- prompt_style: Literal["default", "llama2", "tag"] | None
+ prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.
@@ -137,4 +180,8 @@ def get_prompt_style(
return Llama2PromptStyle()
elif prompt_style == "tag":
return TagPromptStyle()
+ elif prompt_style == "mistral":
+ return MistralPromptStyle()
+ elif prompt_style == "chatml":
+ return ChatMLPromptStyle()
raise ValueError(f"Unknown prompt_style='{prompt_style}'")
diff --git a/private_gpt/open_ai/openai_models.py b/private_gpt/open_ai/openai_models.py
index d9171890..dd78daf3 100644
--- a/private_gpt/open_ai/openai_models.py
+++ b/private_gpt/open_ai/openai_models.py
@@ -118,5 +118,5 @@ def to_openai_sse_stream(
yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
else:
yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
- yield f"data: {OpenAICompletion.json_from_delta(text=None, finish_reason='stop')}\n\n"
+ yield f"data: {OpenAICompletion.json_from_delta(text='', finish_reason='stop')}\n\n"
yield "data: [DONE]\n\n"
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 7c58a762..499ce66d 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -110,13 +110,14 @@ class LocalSettings(BaseModel):
embedding_hf_model_name: str = Field(
description="Name of the HuggingFace model to use for embeddings"
)
- prompt_style: Literal["default", "llama2", "tag"] = Field(
+ prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
"llama2",
description=(
"The prompt style to use for the chat engine. "
"If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
"If `llama2` - use the llama2 prompt style from the llama_index. Based on ``, `[INST]` and `<>`.\n"
"If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
+            "If `mistral` - use the `mistral` prompt style. It should look like `[INST] {System Prompt} [/INST][INST] { UserInstructions } [/INST]`\n"
"`llama2` is the historic behaviour. `default` might work better with your custom models."
),
)
diff --git a/private_gpt/settings/settings_loader.py b/private_gpt/settings/settings_loader.py
index b4052db2..2784cdfc 100644
--- a/private_gpt/settings/settings_loader.py
+++ b/private_gpt/settings/settings_loader.py
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
_settings_folder = os.environ.get("PGPT_SETTINGS_FOLDER", PROJECT_ROOT_PATH)
# if running in unittest, use the test profile
-_test_profile = ["test"] if "unittest" in sys.modules else []
+_test_profile = ["test"] if "tests.fixtures" in sys.modules else []
active_profiles: list[str] = unique_list(
["default"]
diff --git a/scripts/extract_openapi.py b/scripts/extract_openapi.py
index ba6f138a..15840d91 100644
--- a/scripts/extract_openapi.py
+++ b/scripts/extract_openapi.py
@@ -1,6 +1,7 @@
import argparse
import json
import sys
+
import yaml
from uvicorn.importer import import_from_string
diff --git a/settings-docker.yaml b/settings-docker.yaml
index 49b39618..6b915cca 100644
--- a/settings-docker.yaml
+++ b/settings-docker.yaml
@@ -5,6 +5,9 @@ server:
llm:
mode: ${PGPT_MODE:mock}
+embedding:
+ mode: ${PGPT_MODE:sagemaker}
+
local:
llm_hf_repo_id: ${PGPT_HF_REPO_ID:TheBloke/Mistral-7B-Instruct-v0.1-GGUF}
llm_hf_model_file: ${PGPT_HF_MODEL_FILE:mistral-7b-instruct-v0.1.Q4_K_M.gguf}
@@ -16,4 +19,4 @@ sagemaker:
ui:
enabled: true
- path: /
\ No newline at end of file
+ path: /
diff --git a/settings.yaml b/settings.yaml
index a01780cc..cf444d7d 100644
--- a/settings.yaml
+++ b/settings.yaml
@@ -52,7 +52,7 @@ qdrant:
path: local_data/private_gpt/qdrant
local:
- prompt_style: "llama2"
+ prompt_style: "mistral"
llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
embedding_hf_model_name: BAAI/bge-small-en-v1.5
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index 48cac0ba..48597698 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -2,8 +2,10 @@ import pytest
from llama_index.llms import ChatMessage, MessageRole
from private_gpt.components.llm.prompt_helper import (
+ ChatMLPromptStyle,
DefaultPromptStyle,
Llama2PromptStyle,
+ MistralPromptStyle,
TagPromptStyle,
get_prompt_style,
)
@@ -15,6 +17,8 @@ from private_gpt.components.llm.prompt_helper import (
("default", DefaultPromptStyle),
("llama2", Llama2PromptStyle),
("tag", TagPromptStyle),
+ ("mistral", MistralPromptStyle),
+ ("chatml", ChatMLPromptStyle),
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -62,6 +66,39 @@ def test_tag_prompt_style_format_with_system_prompt():
assert prompt_style.messages_to_prompt(messages) == expected_prompt
+def test_mistral_prompt_style_format():
+ prompt_style = MistralPromptStyle()
+ messages = [
+ ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+ ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+ ]
+
+ expected_prompt = (
+ "[INST] You are an AI assistant. [/INST]"
+ "[INST] Hello, how are you doing? [/INST]"
+ )
+
+ assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_chatml_prompt_style_format():
+ prompt_style = ChatMLPromptStyle()
+ messages = [
+ ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
+ ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+ ]
+
+ expected_prompt = (
+ "<|im_start|>system\n"
+ "You are an AI assistant.<|im_end|>\n"
+ "<|im_start|>user\n"
+ "Hello, how are you doing?<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
def test_llama2_prompt_style_format():
prompt_style = Llama2PromptStyle()
messages = [