diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index cfe3a737..e03ac266 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -4,7 +4,6 @@ from injector import inject, singleton
from llama_index.llms import MockLLM
from llama_index.llms.base import LLM
-from private_gpt.components.llm.prompt_helper import get_prompt_style
from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings
@@ -23,8 +22,11 @@ class LLMComponent:
case "local":
from llama_index.llms import LlamaCPP
- prompt_style_cls = get_prompt_style(settings.local.prompt_style)
- prompt_style = prompt_style_cls(
+ from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style
+
+ prompt_style = get_prompt_style(
+ prompt_style=settings.local.prompt_style,
+ template_name=settings.local.template_name,
default_system_prompt=settings.local.default_system_prompt
)
@@ -43,6 +45,7 @@ class LLMComponent:
completion_to_prompt=prompt_style.completion_to_prompt,
verbose=True,
)
+ # prompt_style.improve_prompt_format(llm=cast(LlamaCPP, self.llm))
case "sagemaker":
from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
diff --git a/private_gpt/components/llm/prompt/__init__.py b/private_gpt/components/llm/prompt/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/private_gpt/components/llm/prompt/prompt_helper.py b/private_gpt/components/llm/prompt/prompt_helper.py
new file mode 100644
index 00000000..ace96c40
--- /dev/null
+++ b/private_gpt/components/llm/prompt/prompt_helper.py
@@ -0,0 +1,401 @@
+"""
+Helper to get your llama_index messages correctly serialized into a prompt.
+
+This set of classes and functions is used to format a series of
+llama_index ChatMessage into a prompt (a single string) that is passed
+as-is to the LLM. The LLM then uses this prompt to generate a completion.
+
+There are **MANY** prompt formats; usually, each model has its own.
+Models posted on HuggingFace generally describe the format they expect.
+The original models, shipped through `transformers`, define their format
+in the file `tokenizer_config.json` in the model's directory.
+These prompt formats are usually expressed as Jinja templates (with some
+custom Jinja token definitions), and they can be applied with
+`transformers.AutoTokenizer`, as described in
+https://huggingface.co/docs/transformers/main/chat_templating
+
+
+Examples of `tokenizer_config.json` files:
+https://huggingface.co/bofenghuang/vigogne-2-7b-chat/blob/main/tokenizer_config.json
+https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+https://huggingface.co/HuggingFaceH4/zephyr-7b-beta/blob/main/tokenizer_config.json
+
+The prompt format matters: using the wrong one leads to "hallucinations"
+and other completions that are not relevant.
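+
+For illustration only (this module does not use `transformers` itself),
+applying such a chat template could look roughly like the sketch below;
+the model name is just an example:
+
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+    messages = [{"role": "user", "content": "Hello, how are you?"}]
+    # Render the conversation into a single prompt string, ending with the
+    # tokens that trigger the assistant's answer.
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )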
+"""
+
+import abc
+import logging
+from collections.abc import Sequence
+from pathlib import Path
+from typing import Any, cast
+
+from jinja2 import FileSystemLoader
+from jinja2.exceptions import TemplateError
+from jinja2.sandbox import ImmutableSandboxedEnvironment
+from llama_cpp import llama_types, Llama
+from llama_cpp import llama_chat_format
+from llama_index.llms import ChatMessage, MessageRole, LlamaCPP
+from llama_index.llms.llama_utils import (
+ DEFAULT_SYSTEM_PROMPT,
+ completion_to_prompt,
+ messages_to_prompt,
+)
+
+from private_gpt.constants import PROJECT_ROOT_PATH
+
+logger = logging.getLogger(__name__)
+
+THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
+
+
+_LLAMA_CPP_PYTHON_CHAT_FORMAT = {
+ "llama-2": llama_chat_format.format_llama2,
+ "alpaca": llama_chat_format.format_alpaca,
+ "vicuna": llama_chat_format.format,
+ "oasst_llama": llama_chat_format.format_oasst_llama,
+ "baichuan-2": llama_chat_format.format_baichuan2,
+ "baichuan": llama_chat_format.format_baichuan,
+ "openbuddy": llama_chat_format.format_openbuddy,
+ "redpajama-incite": llama_chat_format.format_redpajama_incite,
+ "snoozy": llama_chat_format.format_snoozy,
+ "phind": llama_chat_format.format_phind,
+ "intel": llama_chat_format.format_intel,
+ "open-orca": llama_chat_format.format_open_orca,
+ "mistrallite": llama_chat_format.format_mistrallite,
+ "zephyr": llama_chat_format.format_zephyr,
+ "chatml": llama_chat_format.format_chatml,
+ "openchat": llama_chat_format.format_openchat,
+}
+
+
+# FIXME partial support
+def llama_index_to_llama_cpp_messages(
+ messages: Sequence[ChatMessage],
+) -> list[llama_types.ChatCompletionRequestMessage]:
+ """Convert messages from llama_index to llama_cpp format.
+
+ Convert a list of llama_index ChatMessage to a
+ list of llama_cpp ChatCompletionRequestMessage.
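+
+    Rough illustration of the mapping (the llama_cpp message types are
+    TypedDicts, so the result is a list of plain dicts):
+
+        [ChatMessage(role=MessageRole.USER, content="Hi")]
+        -> [{"content": "Hi", "role": "user"}]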
+ """
+ llama_cpp_messages: list[llama_types.ChatCompletionRequestMessage] = []
+ for msg in messages:
+ if msg.role == MessageRole.SYSTEM:
+ l_msg = llama_types.ChatCompletionRequestSystemMessage(
+ content=msg.content, role=msg.role.value
+ )
+ elif msg.role == MessageRole.USER:
+ # FIXME partial support
+ l_msg = llama_types.ChatCompletionRequestUserMessage(
+ content=msg.content, role=msg.role.value
+ )
+ elif msg.role == MessageRole.ASSISTANT:
+ # FIXME partial support
+ l_msg = llama_types.ChatCompletionRequestAssistantMessage(
+ content=msg.content, role=msg.role.value
+ )
+ elif msg.role == MessageRole.TOOL:
+ # FIXME partial support
+ l_msg = llama_types.ChatCompletionRequestToolMessage(
+ content=msg.content, role=msg.role.value, tool_call_id=""
+ )
+ elif msg.role == MessageRole.FUNCTION:
+ # FIXME partial support
+ l_msg = llama_types.ChatCompletionRequestFunctionMessage(
+ content=msg.content, role=msg.role.value, name=""
+ )
+ else:
+ raise ValueError(f"Unknown role='{msg.role}'")
+ llama_cpp_messages.append(l_msg)
+ return llama_cpp_messages
+
+
+def _get_llama_cpp_chat_format(name: str) -> llama_chat_format.ChatFormatter:
+    try:
+        return _LLAMA_CPP_PYTHON_CHAT_FORMAT[name]
+    except KeyError as err:
+        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'") from err
+
+
+class AbstractPromptStyle(abc.ABC):
+ """Abstract class for prompt styles.
+
+ This class is used to format a series of messages into a prompt that can be
+ understood by the models. A series of messages represents the interaction(s)
+ between a user and an assistant. This series of messages can be considered as a
+    session between a user X and an assistant Y. This session holds, through the
+ messages, the state of the conversation. This session, to be understood by the
+ model, needs to be formatted into a prompt (i.e. a string that the models
+ can understand). Prompts can be formatted in different ways,
+ depending on the model.
+
+ The implementations of this class represent the different ways to format a
+ series of messages into a prompt.
+ """
+
+ @abc.abstractmethod
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
+ self.bos_token = ""
+ self.eos_token = ""
+ self.nl_token = "\n"
+
+ @abc.abstractmethod
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ pass
+
+ @abc.abstractmethod
+ def _completion_to_prompt(self, completion: str) -> str:
+ pass
+
+ def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ logger.debug("Formatting messages='%s' to prompt", messages)
+ prompt = self._messages_to_prompt(messages)
+ logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
+ return prompt
+
+ def completion_to_prompt(self, completion: str) -> str:
+ logger.debug("Formatting completion='%s' to prompt", completion)
+ prompt = self._completion_to_prompt(completion)
+ logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
+ return prompt
+
+ def improve_prompt_format(self, llm: LlamaCPP) -> None:
+ """Improve the prompt format of the given LLM.
+
+ Use the given metadata in the LLM to improve the prompt format.
+ """
+ # FIXME: below we are getting IDs (1,2,13) from llama.cpp, and not actual strings
+ llama_cpp_llm = cast(Llama, llm._model)
+ self.bos_token = llama_cpp_llm.token_bos()
+ self.eos_token = llama_cpp_llm.token_eos()
+ self.nl_token = llama_cpp_llm.token_nl()
+        logger.debug(
+            "Model special tokens: bos=%s eos=%s nl=%s",
+            self.bos_token,
+            self.eos_token,
+            self.nl_token,
+        )
+ # (1,2,13) are the IDs of the tokens
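+        # A possible follow-up (untested sketch): turn the IDs back into their
+        # text pieces with Llama.detokenize(), which takes a list of token IDs
+        # and returns bytes, e.g.:
+        #   self.bos_token = llama_cpp_llm.detokenize(
+        #       [llama_cpp_llm.token_bos()]
+        #   ).decode("utf-8", errors="ignore")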
+
+
+class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
+ _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
+
+ def __init__(self, default_system_prompt: str | None) -> None:
+ super().__init__()
+ logger.debug("Got default_system_prompt='%s'", default_system_prompt)
+ self.default_system_prompt = default_system_prompt
+
+ def _add_missing_system_prompt(self, messages: Sequence[ChatMessage]) -> list[ChatMessage]:
+ if messages[0].role != MessageRole.SYSTEM:
+ logger.debug(
+ "Adding system_promt='%s' to the given messages as there are none given in the session",
+ self.default_system_prompt,
+ )
+ messages = [
+ ChatMessage(content=self.default_system_prompt, role=MessageRole.SYSTEM),
+ *messages,
+ ]
+ return messages
+
+
+class DefaultPromptStyle(AbstractPromptStyle):
+ """Default prompt style that uses the defaults from llama_utils.
+
+ It basically passes None to the LLM, indicating it should use
+ the default functions.
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+
+ # Hacky way to override the functions
+ # Override the functions to be None, and pass None to the LLM.
+ self.messages_to_prompt = None # type: ignore[method-assign, assignment]
+ self.completion_to_prompt = None # type: ignore[method-assign, assignment]
+
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ """Dummy implementation."""
+ return ""
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ """Dummy implementation."""
+ return ""
+
+
+class LlamaIndexPromptStyle(AbstractPromptStyleWithSystemPrompt):
+ """Simple prompt style that just uses the default llama_utils functions.
+
+ It transforms the sequence of messages into a prompt that should look like:
+ ```text
+    [INST] <<SYS>> your system prompt here. <</SYS>>
+
+ user message here [/INST] assistant (model) response here
+ ```
+ """
+
+ def __init__(self, default_system_prompt: str | None = None) -> None:
+ # If no system prompt is given, the default one of the implementation is used.
+ super().__init__(default_system_prompt=default_system_prompt)
+
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ return messages_to_prompt(messages, self.default_system_prompt)
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ return completion_to_prompt(completion, self.default_system_prompt)
+
+
+class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
+ """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
+
+ It transforms the sequence of messages into a prompt that should look like:
+ ```text
+ <|system|>: your system prompt here.
+ <|user|>: user message here
+ (possibly with context and question)
+ <|assistant|>: assistant (model) response here.
+ ```
+
+    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
+ """
+
+ def __init__(self, default_system_prompt: str | None = None, add_generation_prompt: bool = True) -> None:
+ # We have to define a default system prompt here as the LLM will not
+ # use the default llama_utils functions.
+ default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
+ super().__init__(default_system_prompt)
+ self.system_prompt: str = default_system_prompt
+ self.add_generation_prompt = add_generation_prompt
+
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ messages = self._add_missing_system_prompt(messages)
+ return self._format_messages_to_prompt(messages)
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ messages = [ChatMessage(content=completion, role=MessageRole.USER)]
+ return self._format_messages_to_prompt(messages)
+
+ def _format_messages_to_prompt(self, messages: list[ChatMessage]) -> str:
+        """Format messages to a prompt with the `<|ROLE|>: MSG` style."""
+        # TODO: add BOS and EOS tokens (cf. the Jinja templates)
+ assert messages[0].role == MessageRole.SYSTEM
+ prompt = ""
+ # TODO enclose the interaction between self.token_bos and self.token_eos
+ for message in messages:
+ role = message.role
+ content = message.content or ""
+ message_from_user = f"<|{role.lower()}|>: {content.strip()}"
+ message_from_user += self.nl_token
+ prompt += message_from_user
+ if self.add_generation_prompt:
+ # we are missing the last <|assistant|> tag that will trigger a completion
+ prompt += "<|assistant|>: "
+ return prompt
+
+
+class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
+ def __init__(self, prompt_style: str, default_system_prompt: str | None = None) -> None:
+ """Wrapper for llama_cpp_python defined prompt format.
+ :param prompt_style:
+ :param default_system_prompt: Used if no system prompt is given in the messages.
+ """
+ default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
+ super().__init__(default_system_prompt)
+
+ self.prompt_style = prompt_style
+ if self.prompt_style is None:
+ return
+
+ self._llama_cpp_formatter = _get_llama_cpp_chat_format(self.prompt_style)
+ self.messages_to_prompt = self._messages_to_prompt
+ self.completion_to_prompt = self._completion_to_prompt
+
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ messages = self._add_missing_system_prompt(messages)
+ return self._llama_cpp_formatter(
+ messages=llama_index_to_llama_cpp_messages(messages)
+ ).prompt
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ messages = self._add_missing_system_prompt([ChatMessage(content=completion, role=MessageRole.USER)])
+ return self._llama_cpp_formatter(
+ messages=llama_index_to_llama_cpp_messages(messages)
+ ).prompt
+
+
+class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
+ def __init__(self, template_name: str, add_generation_prompt: bool = True,
+ default_system_prompt: str | None = None) -> None:
+ """Prompt format using a Jinja template.
+
+        :param template_name: the filename of the template to use, must be in
+            the `./template/` directory.
+        :param add_generation_prompt: whether to append the tokens that trigger
+            the assistant's answer at the end of the prompt.
+        :param default_system_prompt: Used if no system prompt is
+            given in the messages.
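+
+        A hypothetical usage, with one of the templates shipped in this
+        directory:
+
+            style = TemplatePromptStyle(
+                template_name="vigogne-2-7b-chat.jinja",
+                default_system_prompt=None,
+            )
+            prompt = style.completion_to_prompt("Bonjour !")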
+ """
+ default_system_prompt = default_system_prompt or DEFAULT_SYSTEM_PROMPT
+ super().__init__(default_system_prompt)
+
+ self._add_generation_prompt = add_generation_prompt
+
+        def raise_exception(message: str) -> None:
+            raise TemplateError(message)
+
+        self._jinja_fs_loader = FileSystemLoader(
+            searchpath=THIS_DIRECTORY_RELATIVE / "template"
+        )
+        self._jinja_env = ImmutableSandboxedEnvironment(
+            loader=self._jinja_fs_loader, trim_blocks=True, lstrip_blocks=True
+        )
+ self._jinja_env.globals["raise_exception"] = raise_exception
+
+ self.template = self._jinja_env.get_template(template_name)
+
+ @property
+ def _extra_kwargs_render(self) -> dict[str, Any]:
+ return {
+ "eos_token": self.eos_token,
+ "bos_token": self.bos_token,
+ "nl_token": self.nl_token,
+ }
+
+ @staticmethod
+ def _j_raise_exception(x: str) -> None:
+ """Helper method to let Jinja template raise exceptions."""
+ raise RuntimeError(x)
+
+ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+ messages = self._add_missing_system_prompt(messages)
+ msgs = [{"role": msg.role.value, "content": msg.content} for msg in messages]
+ return self.template.render(
+ messages=msgs,
+ add_generation_prompt=self._add_generation_prompt,
+ **self._extra_kwargs_render,
+ )
+
+ def _completion_to_prompt(self, completion: str) -> str:
+ messages = self._add_missing_system_prompt([
+ ChatMessage(content=completion, role=MessageRole.USER),
+ ])
+ return self._messages_to_prompt(messages)
+
+
+# TODO Maybe implement an auto-prompt style?
+
+
+# Note: all extra keyword arguments are passed through to the selected prompt style.
+def get_prompt_style(
+ prompt_style: str | None,
+ **kwargs: Any,
+) -> AbstractPromptStyle:
+ """Get the prompt style to use from the given string.
+
+    :param prompt_style: The prompt style to use.
+    :param kwargs: Extra keyword arguments forwarded to the selected prompt
+        style constructor (e.g. `default_system_prompt`, `template_name`).
+    :return: An instance of the prompt style to use.
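+
+    A hypothetical call, mirroring how the local LLM component wires it up:
+
+        style = get_prompt_style(
+            prompt_style="template",
+            template_name="vigogne-2-7b-chat.jinja",
+            default_system_prompt=None,
+        )
+        prompt = style.completion_to_prompt("Hello!")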
+ """
+    # Drop keyword arguments left to None (e.g. `template_name` when no template
+    # is used), so prompt styles that do not accept them are not passed the argument.
+    kwargs = {key: value for key, value in kwargs.items() if value is not None}
+    if prompt_style is None or prompt_style == "default":
+        return DefaultPromptStyle(**kwargs)
+    if prompt_style.startswith("llama_cpp."):
+        return LlamaCppPromptStyle(prompt_style[len("llama_cpp."):], **kwargs)
+ elif prompt_style == "llama2":
+ return LlamaIndexPromptStyle(**kwargs)
+ elif prompt_style == "vigogne":
+ return VigognePromptStyle(**kwargs)
+ elif prompt_style == "template":
+ return TemplatePromptStyle(**kwargs)
+ raise ValueError(f"Unknown prompt_style='{prompt_style}'")
diff --git a/private_gpt/components/llm/prompt/template/Mistral-7B-Instruct-v0.1.jinja b/private_gpt/components/llm/prompt/template/Mistral-7B-Instruct-v0.1.jinja
new file mode 100644
index 00000000..2630528e
--- /dev/null
+++ b/private_gpt/components/llm/prompt/template/Mistral-7B-Instruct-v0.1.jinja
@@ -0,0 +1,2 @@
+{# This template is coming from: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json #}
+{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
\ No newline at end of file
diff --git a/private_gpt/components/llm/prompt/template/vigogne-2-7b-chat.jinja b/private_gpt/components/llm/prompt/template/vigogne-2-7b-chat.jinja
new file mode 100644
index 00000000..8b8378ba
--- /dev/null
+++ b/private_gpt/components/llm/prompt/template/vigogne-2-7b-chat.jinja
@@ -0,0 +1,2 @@
+{# This template is coming from: https://huggingface.co/bofenghuang/vigogne-2-7b-chat/blob/main/tokenizer_config.json #}
+{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|system|>: ' + system_message + '\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>: ' + message['content'].strip() + '\\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>: ' + message['content'].strip() + eos_token + '\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>:' }}{% endif %}
\ No newline at end of file
diff --git a/private_gpt/components/llm/prompt/template/zephyr-7b-beta.jinja b/private_gpt/components/llm/prompt/template/zephyr-7b-beta.jinja
new file mode 100644
index 00000000..4160437e
--- /dev/null
+++ b/private_gpt/components/llm/prompt/template/zephyr-7b-beta.jinja
@@ -0,0 +1,2 @@
+{# This template is coming from: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta/blob/main/tokenizer_config.json #}
+{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}
\ No newline at end of file
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
deleted file mode 100644
index e47b3fb9..00000000
--- a/private_gpt/components/llm/prompt_helper.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import abc
-import logging
-from collections.abc import Sequence
-from typing import Any, Literal
-
-from llama_index.llms import ChatMessage, MessageRole
-from llama_index.llms.llama_utils import (
- DEFAULT_SYSTEM_PROMPT,
- completion_to_prompt,
- messages_to_prompt,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class AbstractPromptStyle(abc.ABC):
- """Abstract class for prompt styles.
-
- This class is used to format a series of messages into a prompt that can be
- understood by the models. A series of messages represents the interaction(s)
- between a user and an assistant. This series of messages can be considered as a
- session between a user X and an assistant Y.This session holds, through the
- messages, the state of the conversation. This session, to be understood by the
- model, needs to be formatted into a prompt (i.e. a string that the models
- can understand). Prompts can be formatted in different ways,
- depending on the model.
-
- The implementations of this class represent the different ways to format a
- series of messages into a prompt.
- """
-
- @abc.abstractmethod
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
-
- @abc.abstractmethod
- def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- pass
-
- @abc.abstractmethod
- def _completion_to_prompt(self, completion: str) -> str:
- pass
-
- def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- prompt = self._messages_to_prompt(messages)
- logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
- return prompt
-
- def completion_to_prompt(self, completion: str) -> str:
- prompt = self._completion_to_prompt(completion)
- logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
- return prompt
-
-
-class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
- _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
-
- def __init__(self, default_system_prompt: str | None) -> None:
- super().__init__()
- logger.debug("Got default_system_prompt='%s'", default_system_prompt)
- self.default_system_prompt = default_system_prompt
-
-
-class DefaultPromptStyle(AbstractPromptStyle):
- """Default prompt style that uses the defaults from llama_utils.
-
- It basically passes None to the LLM, indicating it should use
- the default functions.
- """
-
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, **kwargs)
-
- # Hacky way to override the functions
- # Override the functions to be None, and pass None to the LLM.
- self.messages_to_prompt = None # type: ignore[method-assign, assignment]
- self.completion_to_prompt = None # type: ignore[method-assign, assignment]
-
- def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- return ""
-
- def _completion_to_prompt(self, completion: str) -> str:
- return ""
-
-
-class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
- """Simple prompt style that just uses the default llama_utils functions.
-
- It transforms the sequence of messages into a prompt that should look like:
- ```text
-    [INST] <<SYS>> your system prompt here. <</SYS>>
-
- user message here [/INST] assistant (model) response here
- ```
- """
-
- def __init__(self, default_system_prompt: str | None = None) -> None:
- # If no system prompt is given, the default one of the implementation is used.
- super().__init__(default_system_prompt=default_system_prompt)
-
- def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- return messages_to_prompt(messages, self.default_system_prompt)
-
- def _completion_to_prompt(self, completion: str) -> str:
- return completion_to_prompt(completion, self.default_system_prompt)
-
-
-class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
- """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
-
- It transforms the sequence of messages into a prompt that should look like:
- ```text
- <|system|>: your system prompt here.
- <|user|>: user message here
- (possibly with context and question)
- <|assistant|>: assistant (model) response here.
- ```
-
-    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
- """
-
- def __init__(self, default_system_prompt: str | None = None) -> None:
- # We have to define a default system prompt here as the LLM will not
- # use the default llama_utils functions.
- default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
- super().__init__(default_system_prompt)
- self.system_prompt: str = default_system_prompt
-
- def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- messages = list(messages)
- if messages[0].role != MessageRole.SYSTEM:
- logger.info(
- "Adding system_promt='%s' to the given messages as there are none given in the session",
- self.system_prompt,
- )
- messages = [
- ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
- *messages,
- ]
- return self._format_messages_to_prompt(messages)
-
- def _completion_to_prompt(self, completion: str) -> str:
- return (
- f"<|system|>: {self.system_prompt.strip()}\n"
- f"<|user|>: {completion.strip()}\n"
- "<|assistant|>: "
- )
-
- @staticmethod
- def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
- """Format message to prompt with `<|ROLE|>: MSG` style."""
- assert messages[0].role == MessageRole.SYSTEM
- prompt = ""
- for message in messages:
- role = message.role
- content = message.content or ""
- message_from_user = f"<|{role.lower()}|>: {content.strip()}"
- message_from_user += "\n"
- prompt += message_from_user
- # we are missing the last <|assistant|> tag that will trigger a completion
- prompt += "<|assistant|>: "
- return prompt
-
-
-def get_prompt_style(
- prompt_style: Literal["default", "llama2", "tag"] | None
-) -> type[AbstractPromptStyle]:
- """Get the prompt style to use from the given string.
-
- :param prompt_style: The prompt style to use.
- :return: The prompt style to use.
- """
- if prompt_style is None or prompt_style == "default":
- return DefaultPromptStyle
- elif prompt_style == "llama2":
- return Llama2PromptStyle
- elif prompt_style == "tag":
- return TagPromptStyle
- raise ValueError(f"Unknown prompt_style='{prompt_style}'")
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 125396c3..4a734145 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -98,13 +98,33 @@ class LocalSettings(BaseModel):
embedding_hf_model_name: str = Field(
description="Name of the HuggingFace model to use for embeddings"
)
- prompt_style: Literal["default", "llama2", "tag"] = Field(
+ prompt_style: Literal[
+ "llama_cpp.llama-2",
+ "llama_cpp.alpaca",
+ "llama_cpp.vicuna",
+ "llama_cpp.oasst_llama",
+ "llama_cpp.baichuan-2",
+ "llama_cpp.baichuan",
+ "llama_cpp.openbuddy",
+ "llama_cpp.redpajama-incite",
+ "llama_cpp.snoozy",
+ "llama_cpp.phind",
+ "llama_cpp.intel",
+ "llama_cpp.open-orca",
+ "llama_cpp.mistrallite",
+ "llama_cpp.zephyr",
+ "llama_cpp.chatml",
+ "llama_cpp.openchat",
"llama2",
+ "vigogne",
+ "template",
+ ] | None = Field(
+ None,
description=(
"The prompt style to use for the chat engine. "
- "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
+ "If None is given - use the default prompt style from the llama_index. It should look like `role: message`.\n"
"If `llama2` - use the llama2 prompt style from the llama_index. Based on ``, `[INST]` and `<>`.\n"
- "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
+ "If `llama_cpp.` - use the `` prompt style, implemented by `llama-cpp-python`. \n"
"`llama2` is the historic behaviour. `default` might work better with your custom models."
),
)
@@ -118,6 +138,13 @@ class LocalSettings(BaseModel):
),
)
+ template_name: str | None = Field(
+ None,
+ description=(
+ "The name of the template to use for the chat engine, if the `prompt_style` is `template`."
+ ),
+ )
+
class EmbeddingSettings(BaseModel):
mode: Literal["local", "openai", "sagemaker", "mock"]
diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index eeddb0fb..540235f7 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -9,7 +9,7 @@ import gradio as gr # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from injector import inject, singleton
-from llama_index.llms import ChatMessage, ChatResponse, MessageRole
+from llama_index.llms import ChatMessage, MessageRole
from pydantic import BaseModel
from private_gpt.constants import PROJECT_ROOT_PATH
@@ -55,6 +55,27 @@ class Source(BaseModel):
return curated_sources
+def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
+ full_response: str = ""
+ stream = completion_gen.response
+    for delta in stream:
+        if isinstance(delta, str):
+            full_response += delta
+        else:
+            # ChatResponse-like objects carry the newly generated text in `.delta`.
+            full_response += getattr(delta, "delta", None) or ""
+        yield full_response
+
+ if completion_gen.sources:
+ full_response += SOURCES_SEPARATOR
+ cur_sources = Source.curate_sources(completion_gen.sources)
+ sources_text = "\n\n\n".join(
+ f"{index}. {source.file} (page {source.page})"
+ for index, source in enumerate(cur_sources, start=1)
+ )
+ full_response += sources_text
+ yield full_response
+
+
@singleton
class PrivateGptUi:
@inject
@@ -72,26 +93,6 @@ class PrivateGptUi:
self._ui_block = None
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
- def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
- full_response: str = ""
- stream = completion_gen.response
- for delta in stream:
- if isinstance(delta, str):
- full_response += str(delta)
- elif isinstance(delta, ChatResponse):
- full_response += delta.delta or ""
- yield full_response
-
- if completion_gen.sources:
- full_response += SOURCES_SEPARATOR
- cur_sources = Source.curate_sources(completion_gen.sources)
- sources_text = "\n\n\n".join(
- f"{index}. {source.file} (page {source.page})"
- for index, source in enumerate(cur_sources, start=1)
- )
- full_response += sources_text
- yield full_response
-
def build_history() -> list[ChatMessage]:
history_messages: list[ChatMessage] = list(
itertools.chain(
diff --git a/pyproject.toml b/pyproject.toml
index 7bb3ec5e..73ecc710 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ injector = "^0.21.0"
pyyaml = "^6.0.1"
python-multipart = "^0.0.6"
pypdf = "^3.16.2"
-llama-index = { extras = ["local_models"], version = "0.9.3" }
+llama-index = { extras = ["local_models"], version = "0.9.10" }
watchdog = "^3.0.0"
qdrant-client = "^1.6.9"
chromadb = {version = "^0.4.13", optional = true}
@@ -31,7 +31,7 @@ types-pyyaml = "^6.0.12.12"
[tool.poetry.group.ui]
optional = true
[tool.poetry.group.ui.dependencies]
-gradio = "^4.4.1"
+gradio = "^4.7.1"
[tool.poetry.group.local]
optional = true
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index 1f22a069..c5594691 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -1,10 +1,10 @@
import pytest
from llama_index.llms import ChatMessage, MessageRole
-from private_gpt.components.llm.prompt_helper import (
+from private_gpt.components.llm.prompt.prompt_helper import (
DefaultPromptStyle,
- Llama2PromptStyle,
- TagPromptStyle,
+ LlamaIndexPromptStyle,
+ VigognePromptStyle,
get_prompt_style,
)
@@ -13,8 +13,8 @@ from private_gpt.components.llm.prompt_helper import (
("prompt_style", "expected_prompt_style"),
[
("default", DefaultPromptStyle),
- ("llama2", Llama2PromptStyle),
- ("tag", TagPromptStyle),
+ ("llama2", LlamaIndexPromptStyle),
+ ("tag", VigognePromptStyle),
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -29,7 +29,7 @@ def test_get_prompt_style_failure():
def test_tag_prompt_style_format():
- prompt_style = TagPromptStyle()
+ prompt_style = VigognePromptStyle()
messages = [
ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
@@ -46,7 +46,7 @@ def test_tag_prompt_style_format():
def test_tag_prompt_style_format_with_system_prompt():
system_prompt = "This is a system prompt from configuration."
- prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
+ prompt_style = VigognePromptStyle(default_system_prompt=system_prompt)
messages = [
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]
@@ -76,7 +76,7 @@ def test_tag_prompt_style_format_with_system_prompt():
def test_llama2_prompt_style_format():
- prompt_style = Llama2PromptStyle()
+ prompt_style = LlamaIndexPromptStyle()
messages = [
ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
@@ -95,7 +95,7 @@ def test_llama2_prompt_style_format():
def test_llama2_prompt_style_with_system_prompt():
system_prompt = "This is a system prompt from configuration."
- prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
+ prompt_style = LlamaIndexPromptStyle(default_system_prompt=system_prompt)
messages = [
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]