WIP: more prompt formats, and a more maintainable prompt helper

Louis 2023-12-03 00:48:43 +01:00
parent 3d301d0c6f
commit 76faffb269
11 changed files with 476 additions and 217 deletions

View File

@@ -4,7 +4,6 @@ from injector import inject, singleton
 from llama_index.llms import MockLLM
 from llama_index.llms.base import LLM
-from private_gpt.components.llm.prompt_helper import get_prompt_style
 from private_gpt.paths import models_path
 from private_gpt.settings.settings import Settings
@@ -23,8 +22,11 @@ class LLMComponent:
             case "local":
                 from llama_index.llms import LlamaCPP
-                prompt_style_cls = get_prompt_style(settings.local.prompt_style)
-                prompt_style = prompt_style_cls(
+                from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style
+
+                prompt_style = get_prompt_style(
+                    prompt_style=settings.local.prompt_style,
+                    template_name=settings.local.template_name,
                     default_system_prompt=settings.local.default_system_prompt
                 )
@@ -43,6 +45,7 @@ class LLMComponent:
                     completion_to_prompt=prompt_style.completion_to_prompt,
                     verbose=True,
                 )
+                # prompt_style.improve_prompt_format(llm=cast(LlamaCPP, self.llm))
             case "sagemaker":
                 from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
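For context, `get_prompt_style` used to return a prompt-style class that the caller instantiated; it now returns a ready-to-use instance and forwards extra keyword arguments to the selected style's constructor (the import path also moves to `private_gpt.components.llm.prompt.prompt_helper`). A minimal sketch of the difference; the argument values are placeholders, only the names come from this diff:

```python
from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

# Before this commit (old module private_gpt.components.llm.prompt_helper):
#     prompt_style_cls = get_prompt_style("llama2")
#     prompt_style = prompt_style_cls(default_system_prompt=None)

# After this commit: get_prompt_style builds the instance directly and
# forwards any extra keyword arguments to the selected style's constructor.
prompt_style = get_prompt_style(prompt_style="llama2", default_system_prompt=None)

# The object exposes the two callbacks handed to LlamaCPP:
messages_to_prompt = prompt_style.messages_to_prompt
completion_to_prompt = prompt_style.completion_to_prompt
```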

View File

@@ -0,0 +1,401 @@
"""
Helper to get your llama_index messages correctly serialized into a prompt.
This set of classes and functions is used to format a series of
llama_index ChatMessage into a prompt (a unique string) that will be passed
as is to the LLM. The LLM will then use this prompt to generate a completion.
There are **MANY** formats for prompts; usually, each model has its own format.
Models posted on HuggingFace usually have a description of the format they use.
The original models, that are shipped through `transformers`, have their
format defined in the file `tokenizer_config.json` in the model's directory.
The prompt format are usually defined as a Jinja template (with some custom
Jinja token definitions). These prompt templates are usable using
the `transformers.AutoTokenizer`, as described in
https://huggingface.co/docs/transformers/main/chat_templating
Examples of `tokenizer_config.json` files:
https://huggingface.co/bofenghuang/vigogne-2-7b-chat/blob/main/tokenizer_config.json
https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta/blob/main/tokenizer_config.json
The format of the prompt is important, as if the wrong one is used, it
will lead to "hallucinations" and other completions that are not relevant.
"""
import abc
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast

from jinja2 import FileSystemLoader
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from llama_cpp import llama_types, Llama
from llama_cpp import llama_chat_format
from llama_index.llms import ChatMessage, MessageRole, LlamaCPP
from llama_index.llms.llama_utils import (
    DEFAULT_SYSTEM_PROMPT,
    completion_to_prompt,
    messages_to_prompt,
)

from private_gpt.constants import PROJECT_ROOT_PATH

logger = logging.getLogger(__name__)

THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)

_LLAMA_CPP_PYTHON_CHAT_FORMAT = {
    "llama-2": llama_chat_format.format_llama2,
    "alpaca": llama_chat_format.format_alpaca,
    "vicuna": llama_chat_format.format,
    "oasst_llama": llama_chat_format.format_oasst_llama,
    "baichuan-2": llama_chat_format.format_baichuan2,
    "baichuan": llama_chat_format.format_baichuan,
    "openbuddy": llama_chat_format.format_openbuddy,
    "redpajama-incite": llama_chat_format.format_redpajama_incite,
    "snoozy": llama_chat_format.format_snoozy,
    "phind": llama_chat_format.format_phind,
    "intel": llama_chat_format.format_intel,
    "open-orca": llama_chat_format.format_open_orca,
    "mistrallite": llama_chat_format.format_mistrallite,
    "zephyr": llama_chat_format.format_zephyr,
    "chatml": llama_chat_format.format_chatml,
    "openchat": llama_chat_format.format_openchat,
}

# FIXME partial support
def llama_index_to_llama_cpp_messages(
    messages: Sequence[ChatMessage],
) -> list[llama_types.ChatCompletionRequestMessage]:
    """Convert messages from llama_index to llama_cpp format.

    Convert a list of llama_index ChatMessage to a
    list of llama_cpp ChatCompletionRequestMessage.
    """
    llama_cpp_messages: list[llama_types.ChatCompletionRequestMessage] = []
    for msg in messages:
        if msg.role == MessageRole.SYSTEM:
            l_msg = llama_types.ChatCompletionRequestSystemMessage(
                content=msg.content, role=msg.role.value
            )
        elif msg.role == MessageRole.USER:
            # FIXME partial support
            l_msg = llama_types.ChatCompletionRequestUserMessage(
                content=msg.content, role=msg.role.value
            )
        elif msg.role == MessageRole.ASSISTANT:
            # FIXME partial support
            l_msg = llama_types.ChatCompletionRequestAssistantMessage(
                content=msg.content, role=msg.role.value
            )
        elif msg.role == MessageRole.TOOL:
            # FIXME partial support
            l_msg = llama_types.ChatCompletionRequestToolMessage(
                content=msg.content, role=msg.role.value, tool_call_id=""
            )
        elif msg.role == MessageRole.FUNCTION:
            # FIXME partial support
            l_msg = llama_types.ChatCompletionRequestFunctionMessage(
                content=msg.content, role=msg.role.value, name=""
            )
        else:
            raise ValueError(f"Unknown role='{msg.role}'")
        llama_cpp_messages.append(l_msg)
    return llama_cpp_messages


def _get_llama_cpp_chat_format(name: str) -> llama_chat_format.ChatFormatter:
    try:
        return _LLAMA_CPP_PYTHON_CHAT_FORMAT[name]
    except KeyError:
        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'")

class AbstractPromptStyle(abc.ABC):
    """Abstract class for prompt styles.

    This class is used to format a series of messages into a prompt that can be
    understood by the models. A series of messages represents the interaction(s)
    between a user and an assistant. This series of messages can be considered as
    a session between a user X and an assistant Y. This session holds, through the
    messages, the state of the conversation. This session, to be understood by the
    model, needs to be formatted into a prompt (i.e. a string that the models
    can understand). Prompts can be formatted in different ways,
    depending on the model.

    The implementations of this class represent the different ways to format a
    series of messages into a prompt.
    """

    @abc.abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.nl_token = "\n"

    @abc.abstractmethod
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        pass

    @abc.abstractmethod
    def _completion_to_prompt(self, completion: str) -> str:
        pass

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        logger.debug("Formatting messages='%s' to prompt", messages)
        prompt = self._messages_to_prompt(messages)
        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
        return prompt

    def completion_to_prompt(self, completion: str) -> str:
        logger.debug("Formatting completion='%s' to prompt", completion)
        prompt = self._completion_to_prompt(completion)
        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
        return prompt

    def improve_prompt_format(self, llm: LlamaCPP) -> None:
        """Improve the prompt format of the given LLM.

        Use the metadata of the given LLM to improve the prompt format.
        """
        # FIXME: below we are getting IDs (1, 2, 13) from llama.cpp, and not actual strings
        llama_cpp_llm = cast(Llama, llm._model)
        self.bos_token = llama_cpp_llm.token_bos()
        self.eos_token = llama_cpp_llm.token_eos()
        self.nl_token = llama_cpp_llm.token_nl()
        print([self.bos_token, self.eos_token, self.nl_token])
        # (1, 2, 13) are the IDs of the tokens

class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
    _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT

    def __init__(self, default_system_prompt: str | None) -> None:
        super().__init__()
        logger.debug("Got default_system_prompt='%s'", default_system_prompt)
        self.default_system_prompt = default_system_prompt

    def _add_missing_system_prompt(
        self, messages: Sequence[ChatMessage]
    ) -> list[ChatMessage]:
        if messages[0].role != MessageRole.SYSTEM:
            logger.debug(
                "Adding system_prompt='%s' to the given messages as none was given in the session",
                self.default_system_prompt,
            )
            messages = [
                ChatMessage(content=self.default_system_prompt, role=MessageRole.SYSTEM),
                *messages,
            ]
        return messages

class DefaultPromptStyle(AbstractPromptStyle):
    """Default prompt style that uses the defaults from llama_utils.

    It basically passes None to the LLM, indicating it should use
    the default functions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Hacky way to override the functions
        # Override the functions to be None, and pass None to the LLM.
        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Dummy implementation."""
        return ""

    def _completion_to_prompt(self, completion: str) -> str:
        """Dummy implementation."""
        return ""

class LlamaIndexPromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Simple prompt style that just uses the default llama_utils functions.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <s> [INST] <<SYS>> your system prompt here. <</SYS>>
    user message here [/INST] assistant (model) response here </s>
    ```
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        # If no system prompt is given, the default one of the implementation is used.
        super().__init__(default_system_prompt=default_system_prompt)

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return messages_to_prompt(messages, self.default_system_prompt)

    def _completion_to_prompt(self, completion: str) -> str:
        return completion_to_prompt(completion, self.default_system_prompt)

class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <|system|>: your system prompt here.
    <|user|>: user message here
    (possibly with context and question)
    <|assistant|>: assistant (model) response here.
    ```

    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
    """

    def __init__(
        self,
        default_system_prompt: str | None = None,
        add_generation_prompt: bool = True,
    ) -> None:
        # We have to define a default system prompt here as the LLM will not
        # use the default llama_utils functions.
        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
        super().__init__(default_system_prompt)
        self.system_prompt: str = default_system_prompt
        self.add_generation_prompt = add_generation_prompt

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        messages = self._add_missing_system_prompt(messages)
        return self._format_messages_to_prompt(messages)

    def _completion_to_prompt(self, completion: str) -> str:
        # A lone completion is treated as a user message; make sure a system
        # prompt is present before formatting, as the formatter expects one.
        messages = self._add_missing_system_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
        return self._format_messages_to_prompt(messages)

    def _format_messages_to_prompt(self, messages: list[ChatMessage]) -> str:
        """Format messages to prompt with `<|ROLE|>: MSG` style."""
        # TODO add BOS and EOS tokens (c.f. the Jinja template)
        assert messages[0].role == MessageRole.SYSTEM
        prompt = ""
        # TODO enclose the interaction between self.bos_token and self.eos_token
        for message in messages:
            role = message.role
            content = message.content or ""
            message_from_user = f"<|{role.lower()}|>: {content.strip()}"
            message_from_user += self.nl_token
            prompt += message_from_user
        if self.add_generation_prompt:
            # we are missing the last <|assistant|> tag that will trigger a completion
            prompt += "<|assistant|>: "
        return prompt

class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
    def __init__(
        self, prompt_style: str, default_system_prompt: str | None = None
    ) -> None:
        """Wrapper for the prompt formats defined by llama_cpp_python.

        :param prompt_style: the name of the llama_cpp_python chat format to use.
        :param default_system_prompt: used if no system prompt is given in the messages.
        """
        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
        super().__init__(default_system_prompt)
        self.prompt_style = prompt_style

        if self.prompt_style is None:
            return

        self._llama_cpp_formatter = _get_llama_cpp_chat_format(self.prompt_style)
        self.messages_to_prompt = self._messages_to_prompt
        self.completion_to_prompt = self._completion_to_prompt

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        messages = self._add_missing_system_prompt(messages)
        return self._llama_cpp_formatter(
            messages=llama_index_to_llama_cpp_messages(messages)
        ).prompt

    def _completion_to_prompt(self, completion: str) -> str:
        messages = self._add_missing_system_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
        return self._llama_cpp_formatter(
            messages=llama_index_to_llama_cpp_messages(messages)
        ).prompt

class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
    def __init__(
        self,
        template_name: str,
        add_generation_prompt: bool = True,
        default_system_prompt: str | None = None,
    ) -> None:
        """Prompt format using a Jinja template.

        :param template_name: the filename of the template to use, must be in
            the `./template/` directory.
        :param add_generation_prompt: whether to append the tag that triggers
            the assistant completion.
        :param default_system_prompt: used if no system prompt is
            given in the messages.
        """
        default_system_prompt = default_system_prompt or DEFAULT_SYSTEM_PROMPT
        super().__init__(default_system_prompt)
        self._add_generation_prompt = add_generation_prompt

        def raise_exception(message: str) -> None:
            raise TemplateError(message)

        self._jinja_fs_loader = FileSystemLoader(
            searchpath=THIS_DIRECTORY_RELATIVE / "template"
        )
        self._jinja_env = ImmutableSandboxedEnvironment(
            loader=self._jinja_fs_loader, trim_blocks=True, lstrip_blocks=True
        )
        self._jinja_env.globals["raise_exception"] = raise_exception

        self.template = self._jinja_env.get_template(template_name)

    @property
    def _extra_kwargs_render(self) -> dict[str, Any]:
        return {
            "eos_token": self.eos_token,
            "bos_token": self.bos_token,
            "nl_token": self.nl_token,
        }

    @staticmethod
    def _j_raise_exception(x: str) -> None:
        """Helper method to let the Jinja template raise exceptions."""
        raise RuntimeError(x)

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        messages = self._add_missing_system_prompt(messages)
        msgs = [{"role": msg.role.value, "content": msg.content} for msg in messages]
        return self.template.render(
            messages=msgs,
            add_generation_prompt=self._add_generation_prompt,
            **self._extra_kwargs_render,
        )

    def _completion_to_prompt(self, completion: str) -> str:
        messages = self._add_missing_system_prompt(
            [ChatMessage(content=completion, role=MessageRole.USER)]
        )
        return self._messages_to_prompt(messages)

# TODO Maybe implement an auto-prompt style?
# Pass all the arguments at once
def get_prompt_style(
    prompt_style: str | None,
    **kwargs: Any,
) -> AbstractPromptStyle:
    """Get the prompt style to use from the given string.

    :param prompt_style: The prompt style to use.
    :return: The prompt style to use.
    """
    if prompt_style is None or prompt_style == "default":
        return DefaultPromptStyle(**kwargs)
    if prompt_style.startswith("llama_cpp."):
        prompt_style = prompt_style[len("llama_cpp."):]
        return LlamaCppPromptStyle(prompt_style, **kwargs)
    elif prompt_style == "llama2":
        return LlamaIndexPromptStyle(**kwargs)
    elif prompt_style == "vigogne":
        return VigognePromptStyle(**kwargs)
    elif prompt_style == "template":
        return TemplatePromptStyle(**kwargs)
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")
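The helper can be exercised on its own, outside the LLM component. A minimal sketch; the message contents are made up, and the expected output follows the `VigognePromptStyle` code above:

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

# Pick the tag-based style used by Vigogne models.
prompt_style = get_prompt_style(prompt_style="vigogne")

messages = [
    ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
    ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

# Produces roughly:
# <|system|>: You are an AI assistant.
# <|user|>: Hello, how are you doing?
# <|assistant|>:
print(prompt_style.messages_to_prompt(messages))
```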

View File

@@ -0,0 +1,2 @@
{# This template is coming from: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json #}
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
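To see what this template produces on its own, it can be rendered directly with Jinja2, mirroring how `TemplatePromptStyle` sets up its environment. A sketch, assuming `template_source` holds the one-line template above (the file path below is hypothetical, since the filenames are not shown in this diff) and using the class defaults for `bos_token`/`eos_token`:

```python
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment


def raise_exception(message: str) -> None:
    raise TemplateError(message)


# Assumed location of the template added in this diff.
template_source = open("template/mistral.jinja").read()

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.globals["raise_exception"] = raise_exception

prompt = env.from_string(template_source).render(
    messages=[
        {"role": "user", "content": "Hello, how are you doing?"},
        {"role": "assistant", "content": "Fine, thanks!"},
        {"role": "user", "content": "Great."},
    ],
    bos_token="<s>",
    eos_token="</s>",
)
# Renders roughly to:
# <s>[INST] Hello, how are you doing? [/INST]Fine, thanks!</s> [INST] Great. [/INST]
print(prompt)
```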

View File

@@ -0,0 +1,2 @@
{# This template is coming from: https://huggingface.co/bofenghuang/vigogne-2-7b-chat/blob/main/tokenizer_config.json #}
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|system|>: ' + system_message + '\\n' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>: ' + message['content'].strip() + '\\n' }}{% elif message['role'] == 'assistant' %}{{ '<|assistant|>: ' + message['content'].strip() + eos_token + '\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>:' }}{% endif %}

View File

@@ -0,0 +1,2 @@
{# This template is coming from: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta/blob/main/tokenizer_config.json #}
{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}
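The Zephyr template handles the system role itself, so it can be driven through `get_prompt_style` and `TemplatePromptStyle`. A sketch, assuming this file is saved as `zephyr.jinja` in the `template/` directory the helper loads from (the actual filename is not visible in this diff) and that the snippet runs from the project root:

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

prompt_style = get_prompt_style(
    prompt_style="template",
    template_name="zephyr.jinja",  # assumed filename
    default_system_prompt="You are an AI assistant.",
)

messages = [
    ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

# The missing system message is prepended automatically, then the template
# renders a Zephyr-style prompt ending with the '<|assistant|>' generation tag.
print(prompt_style.messages_to_prompt(messages))
```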

View File

@@ -1,179 +0,0 @@
import abc
import logging
from collections.abc import Sequence
from typing import Any, Literal

from llama_index.llms import ChatMessage, MessageRole
from llama_index.llms.llama_utils import (
    DEFAULT_SYSTEM_PROMPT,
    completion_to_prompt,
    messages_to_prompt,
)

logger = logging.getLogger(__name__)

class AbstractPromptStyle(abc.ABC):
    """Abstract class for prompt styles.

    This class is used to format a series of messages into a prompt that can be
    understood by the models. A series of messages represents the interaction(s)
    between a user and an assistant. This series of messages can be considered as
    a session between a user X and an assistant Y. This session holds, through the
    messages, the state of the conversation. This session, to be understood by the
    model, needs to be formatted into a prompt (i.e. a string that the models
    can understand). Prompts can be formatted in different ways,
    depending on the model.

    The implementations of this class represent the different ways to format a
    series of messages into a prompt.
    """

    @abc.abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)

    @abc.abstractmethod
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        pass

    @abc.abstractmethod
    def _completion_to_prompt(self, completion: str) -> str:
        pass

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = self._messages_to_prompt(messages)
        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
        return prompt

    def completion_to_prompt(self, completion: str) -> str:
        prompt = self._completion_to_prompt(completion)
        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
        return prompt

class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
    _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT

    def __init__(self, default_system_prompt: str | None) -> None:
        super().__init__()
        logger.debug("Got default_system_prompt='%s'", default_system_prompt)
        self.default_system_prompt = default_system_prompt

class DefaultPromptStyle(AbstractPromptStyle):
    """Default prompt style that uses the defaults from llama_utils.

    It basically passes None to the LLM, indicating it should use
    the default functions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Hacky way to override the functions
        # Override the functions to be None, and pass None to the LLM.
        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return ""

    def _completion_to_prompt(self, completion: str) -> str:
        return ""

class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Simple prompt style that just uses the default llama_utils functions.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <s> [INST] <<SYS>> your system prompt here. <</SYS>>
    user message here [/INST] assistant (model) response here </s>
    ```
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        # If no system prompt is given, the default one of the implementation is used.
        super().__init__(default_system_prompt=default_system_prompt)

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return messages_to_prompt(messages, self.default_system_prompt)

    def _completion_to_prompt(self, completion: str) -> str:
        return completion_to_prompt(completion, self.default_system_prompt)

class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <|system|>: your system prompt here.
    <|user|>: user message here
    (possibly with context and question)
    <|assistant|>: assistant (model) response here.
    ```

    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        # We have to define a default system prompt here as the LLM will not
        # use the default llama_utils functions.
        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
        super().__init__(default_system_prompt)
        self.system_prompt: str = default_system_prompt

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        messages = list(messages)
        if messages[0].role != MessageRole.SYSTEM:
            logger.info(
                "Adding system_prompt='%s' to the given messages as none was given in the session",
                self.system_prompt,
            )
            messages = [
                ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
                *messages,
            ]
        return self._format_messages_to_prompt(messages)

    def _completion_to_prompt(self, completion: str) -> str:
        return (
            f"<|system|>: {self.system_prompt.strip()}\n"
            f"<|user|>: {completion.strip()}\n"
            "<|assistant|>: "
        )

    @staticmethod
    def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
        """Format messages to prompt with `<|ROLE|>: MSG` style."""
        assert messages[0].role == MessageRole.SYSTEM
        prompt = ""
        for message in messages:
            role = message.role
            content = message.content or ""
            message_from_user = f"<|{role.lower()}|>: {content.strip()}"
            message_from_user += "\n"
            prompt += message_from_user
        # we are missing the last <|assistant|> tag that will trigger a completion
        prompt += "<|assistant|>: "
        return prompt

def get_prompt_style(
    prompt_style: Literal["default", "llama2", "tag"] | None
) -> type[AbstractPromptStyle]:
    """Get the prompt style to use from the given string.

    :param prompt_style: The prompt style to use.
    :return: The prompt style to use.
    """
    if prompt_style is None or prompt_style == "default":
        return DefaultPromptStyle
    elif prompt_style == "llama2":
        return Llama2PromptStyle
    elif prompt_style == "tag":
        return TagPromptStyle
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")

View File

@@ -98,13 +98,33 @@ class LocalSettings(BaseModel):
     embedding_hf_model_name: str = Field(
         description="Name of the HuggingFace model to use for embeddings"
     )
-    prompt_style: Literal["default", "llama2", "tag"] = Field(
+    prompt_style: Literal[
+        "llama_cpp.llama-2",
+        "llama_cpp.alpaca",
+        "llama_cpp.vicuna",
+        "llama_cpp.oasst_llama",
+        "llama_cpp.baichuan-2",
+        "llama_cpp.baichuan",
+        "llama_cpp.openbuddy",
+        "llama_cpp.redpajama-incite",
+        "llama_cpp.snoozy",
+        "llama_cpp.phind",
+        "llama_cpp.intel",
+        "llama_cpp.open-orca",
+        "llama_cpp.mistrallite",
+        "llama_cpp.zephyr",
+        "llama_cpp.chatml",
+        "llama_cpp.openchat",
         "llama2",
+        "vigogne",
+        "template",
+    ] | None = Field(
+        None,
         description=(
             "The prompt style to use for the chat engine. "
-            "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
+            "If None is given - use the default prompt style from the llama_index. It should look like `role: message`.\n"
             "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
-            "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
+            "If `llama_cpp.<name>` - use the `<name>` prompt style, implemented by `llama-cpp-python`. \n"
             "`llama2` is the historic behaviour. `default` might work better with your custom models."
         ),
     )
@@ -118,6 +138,13 @@ class LocalSettings(BaseModel):
         ),
     )
+    template_name: str | None = Field(
+        None,
+        description=(
+            "The name of the template to use for the chat engine, if the `prompt_style` is `template`."
+        ),
+    )
+
 class EmbeddingSettings(BaseModel):
     mode: Literal["local", "openai", "sagemaker", "mock"]
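The `llama_cpp.<name>` values accepted here are stripped of their prefix and looked up in the `_LLAMA_CPP_PYTHON_CHAT_FORMAT` table of the new prompt helper, which maps them onto the chat formatters bundled with `llama-cpp-python`. A small sketch of that path; the message content is made up:

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

# "llama_cpp.chatml" selects llama_chat_format.format_chatml via the lookup table.
prompt_style = get_prompt_style(prompt_style="llama_cpp.chatml")

prompt = prompt_style.messages_to_prompt(
    [ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER)]
)
# The default system prompt is prepended, then the llama-cpp-python ChatML
# formatter serializes the conversation (<|im_start|> / <|im_end|> tags).
print(prompt)
```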

View File

@@ -9,7 +9,7 @@ import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
 from injector import inject, singleton
-from llama_index.llms import ChatMessage, ChatResponse, MessageRole
+from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from private_gpt.constants import PROJECT_ROOT_PATH
@@ -55,6 +55,27 @@ class Source(BaseModel):
         return curated_sources

+def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
+    full_response: str = ""
+    stream = completion_gen.response
+    for delta in stream:
+        # if isinstance(delta, str):
+        full_response += str(delta)
+        # elif isinstance(delta, ChatResponse):
+        #     full_response += delta.delta or ""
+        yield full_response
+
+    if completion_gen.sources:
+        full_response += SOURCES_SEPARATOR
+        cur_sources = Source.curate_sources(completion_gen.sources)
+        sources_text = "\n\n\n".join(
+            f"{index}. {source.file} (page {source.page})"
+            for index, source in enumerate(cur_sources, start=1)
+        )
+        full_response += sources_text
+        yield full_response
+
 @singleton
 class PrivateGptUi:
     @inject
@@ -72,26 +93,6 @@ class PrivateGptUi:
         self._ui_block = None

     def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
-            full_response: str = ""
-            stream = completion_gen.response
-            for delta in stream:
-                if isinstance(delta, str):
-                    full_response += str(delta)
-                elif isinstance(delta, ChatResponse):
-                    full_response += delta.delta or ""
-                yield full_response
-
-            if completion_gen.sources:
-                full_response += SOURCES_SEPARATOR
-                cur_sources = Source.curate_sources(completion_gen.sources)
-                sources_text = "\n\n\n".join(
-                    f"{index}. {source.file} (page {source.page})"
-                    for index, source in enumerate(cur_sources, start=1)
-                )
-                full_response += sources_text
-                yield full_response
-
         def build_history() -> list[ChatMessage]:
             history_messages: list[ChatMessage] = list(
                 itertools.chain(
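Moving `yield_deltas` to module level does not change how the UI consumes it: the chat handler still iterates the generator and hands each accumulated string to Gradio. A rough, abbreviated sketch of that consumption; the service call and surrounding class are elided, and names other than `yield_deltas` and `CompletionGen` are placeholders:

```python
# Inside PrivateGptUi._chat (abbreviated).
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
    completion_gen: CompletionGen = ...  # e.g. the streaming result of the chat service

    # yield_deltas yields the whole response accumulated so far on each delta,
    # so the UI can simply re-render the last chat message on every iteration.
    yield from yield_deltas(completion_gen)
```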

View File

@@ -12,7 +12,7 @@ injector = "^0.21.0"
 pyyaml = "^6.0.1"
 python-multipart = "^0.0.6"
 pypdf = "^3.16.2"
-llama-index = { extras = ["local_models"], version = "0.9.3" }
+llama-index = { extras = ["local_models"], version = "0.9.10" }
 watchdog = "^3.0.0"
 qdrant-client = "^1.6.9"
 chromadb = {version = "^0.4.13", optional = true}
@@ -31,7 +31,7 @@ types-pyyaml = "^6.0.12.12"
 [tool.poetry.group.ui]
 optional = true

 [tool.poetry.group.ui.dependencies]
-gradio = "^4.4.1"
+gradio = "^4.7.1"

 [tool.poetry.group.local]
 optional = true

View File

@@ -1,10 +1,10 @@
 import pytest
 from llama_index.llms import ChatMessage, MessageRole

-from private_gpt.components.llm.prompt_helper import (
+from private_gpt.components.llm.prompt.prompt_helper import (
     DefaultPromptStyle,
-    Llama2PromptStyle,
-    TagPromptStyle,
+    LlamaIndexPromptStyle,
+    VigognePromptStyle,
     get_prompt_style,
 )
@@ -13,8 +13,8 @@ from private_gpt.components.llm.prompt_helper import (
     ("prompt_style", "expected_prompt_style"),
     [
         ("default", DefaultPromptStyle),
-        ("llama2", Llama2PromptStyle),
-        ("tag", TagPromptStyle),
+        ("llama2", LlamaIndexPromptStyle),
+        ("tag", VigognePromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@@ -29,7 +29,7 @@ def test_get_prompt_style_failure():
 def test_tag_prompt_style_format():
-    prompt_style = TagPromptStyle()
+    prompt_style = VigognePromptStyle()
     messages = [
         ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
         ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
@@ -46,7 +46,7 @@ def test_tag_prompt_style_format():
 def test_tag_prompt_style_format_with_system_prompt():
     system_prompt = "This is a system prompt from configuration."
-    prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
+    prompt_style = VigognePromptStyle(default_system_prompt=system_prompt)
     messages = [
         ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
     ]
@@ -76,7 +76,7 @@ def test_tag_prompt_style_format_with_system_prompt():
 def test_llama2_prompt_style_format():
-    prompt_style = Llama2PromptStyle()
+    prompt_style = LlamaIndexPromptStyle()
     messages = [
         ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
         ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
@@ -95,7 +95,7 @@ def test_llama2_prompt_style_format():
 def test_llama2_prompt_style_with_system_prompt():
     system_prompt = "This is a system prompt from configuration."
-    prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
+    prompt_style = LlamaIndexPromptStyle(default_system_prompt=system_prompt)
     messages = [
         ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
     ]
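The renamed classes keep the existing `llama2`/`tag` spellings working through `get_prompt_style`, but the new style names do not appear to be covered by these tests yet. A possible parametrized addition, shown only as a sketch of how the new names could be exercised (not part of this diff):

```python
import pytest

from private_gpt.components.llm.prompt.prompt_helper import (
    LlamaCppPromptStyle,
    VigognePromptStyle,
    get_prompt_style,
)


@pytest.mark.parametrize(
    ("prompt_style", "expected_prompt_style"),
    [
        ("vigogne", VigognePromptStyle),
        ("llama_cpp.chatml", LlamaCppPromptStyle),
    ],
)
def test_get_prompt_style_new_names(prompt_style, expected_prompt_style):
    # The factory should return an instance (not a class) of the expected style.
    assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)
```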