Mirror of https://github.com/imartinez/privateGPT.git (synced 2025-07-31 23:16:58 +00:00)

Commit 5bc5054000: Fix typing, linting and add tests
Parent: 76faffb269
@@ -22,12 +22,14 @@ class LLMComponent:
             case "local":
                 from llama_index.llms import LlamaCPP

-                from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style
+                from private_gpt.components.llm.prompt.prompt_helper import (
+                    get_prompt_style,
+                )

                 prompt_style = get_prompt_style(
                     prompt_style=settings.local.prompt_style,
                     template_name=settings.local.template_name,
-                    default_system_prompt=settings.local.default_system_prompt
+                    default_system_prompt=settings.local.default_system_prompt,
                 )

                 self.llm = LlamaCPP(
@@ -1,5 +1,4 @@
-"""
-Helper to get your llama_index messages correctly serialized into a prompt.
+"""Helper to get your llama_index messages correctly serialized into a prompt.

 This set of classes and functions is used to format a series of
 llama_index ChatMessage into a prompt (a unique string) that will be passed
@@ -29,14 +28,13 @@ import abc
 import logging
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, cast
+from typing import Any

 from jinja2 import FileSystemLoader
 from jinja2.exceptions import TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from llama_cpp import llama_types, Llama
-from llama_cpp import llama_chat_format
-from llama_index.llms import ChatMessage, MessageRole, LlamaCPP
+from llama_cpp import llama_chat_format, llama_types
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.llms.llama_utils import (
     DEFAULT_SYSTEM_PROMPT,
     completion_to_prompt,
@@ -50,7 +48,7 @@ logger = logging.getLogger(__name__)
 THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)


-_LLAMA_CPP_PYTHON_CHAT_FORMAT = {
+_LLAMA_CPP_PYTHON_CHAT_FORMAT: dict[str, llama_chat_format.ChatFormatter] = {
     "llama-2": llama_chat_format.format_llama2,
     "alpaca": llama_chat_format.format_alpaca,
     "vicuna": llama_chat_format.format,
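Note: the only change in this hunk is the dict[str, llama_chat_format.ChatFormatter] annotation on the module-level registry, one of the typing fixes named in the commit message. As a generic illustration (the registry, formatter and entries below are invented, not from the codebase), annotating such a registry lets mypy reject badly typed entries at the definition site:

# Illustrative only; names and entries are invented.
from collections.abc import Callable

# Each formatter turns a list of {"role": ..., "content": ...} dicts into a prompt string.
Formatter = Callable[[list[dict[str, str]]], str]


def format_tagged(messages: list[dict[str, str]]) -> str:
    return "".join(f"<|{m['role']}|>: {m['content']}\n" for m in messages)


_FORMATTERS: dict[str, Formatter] = {
    "tag": format_tagged,
    # "broken": 42,  # mypy: int is not assignable to Formatter
}

assert _FORMATTERS["tag"]([{"role": "user", "content": "hi"}]) == "<|user|>: hi\n"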
@@ -80,6 +78,7 @@ def llama_index_to_llama_cpp_messages(
     list of llama_cpp ChatCompletionRequestMessage.
     """
     llama_cpp_messages: list[llama_types.ChatCompletionRequestMessage] = []
+    l_msg: llama_types.ChatCompletionRequestMessage
     for msg in messages:
         if msg.role == MessageRole.SYSTEM:
             l_msg = llama_types.ChatCompletionRequestSystemMessage(
@@ -112,10 +111,11 @@ def llama_index_to_llama_cpp_messages(


 def _get_llama_cpp_chat_format(name: str) -> llama_chat_format.ChatFormatter:
     logger.debug("Getting llama_cpp_python prompt_format='%s'", name)
     try:
         return _LLAMA_CPP_PYTHON_CHAT_FORMAT[name]
-    except KeyError:
-        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'")
+    except KeyError as err:
+        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'") from err


 class AbstractPromptStyle(abc.ABC):
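Note: the "raise ... from err" change addresses the linter rule that flags re-raising inside an except block without explicit chaining, which would otherwise hide the original KeyError from the traceback. A minimal standalone illustration, unrelated to this codebase:

# Minimal illustration of exception chaining; names are invented.
REGISTRY = {"known": 1}


def lookup(name: str) -> int:
    try:
        return REGISTRY[name]
    except KeyError as err:
        # "from err" keeps the original KeyError attached as __cause__,
        # so the traceback shows both the failed lookup and this ValueError.
        raise ValueError(f"Unknown entry '{name}'") from err


try:
    lookup("missing")
except ValueError as exc:
    assert isinstance(exc.__cause__, KeyError)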
@@ -161,18 +161,18 @@ class AbstractPromptStyle(abc.ABC):
         logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
         return prompt

-    def improve_prompt_format(self, llm: LlamaCPP) -> None:
-        """Improve the prompt format of the given LLM.
-
-        Use the given metadata in the LLM to improve the prompt format.
-        """
-        # FIXME: below we are getting IDs (1,2,13) from llama.cpp, and not actual strings
-        llama_cpp_llm = cast(Llama, llm._model)
-        self.bos_token = llama_cpp_llm.token_bos()
-        self.eos_token = llama_cpp_llm.token_eos()
-        self.nl_token = llama_cpp_llm.token_nl()
-        print([self.bos_token, self.eos_token, self.nl_token])
-        # (1,2,13) are the IDs of the tokens
+    # def improve_prompt_format(self, llm: LlamaCPP) -> None:
+    #     """Improve the prompt format of the given LLM.
+    #
+    #     Use the given metadata in the LLM to improve the prompt format.
+    #     """
+    #     # FIXME: we are getting IDs (1,2,13) from llama.cpp, and not actual strings
+    #     llama_cpp_llm = cast(Llama, llm._model)
+    #     self.bos_token = llama_cpp_llm.token_bos()
+    #     self.eos_token = llama_cpp_llm.token_eos()
+    #     self.nl_token = llama_cpp_llm.token_nl()
+    #     print([self.bos_token, self.eos_token, self.nl_token])
+    #     # (1,2,13) are the IDs of the tokens


 class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
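Note: this hunk comments out improve_prompt_format rather than fixing it, since llama.cpp reports the special tokens as integer IDs (1, 2, 13) instead of strings. Purely as a hedged sketch of one way those IDs could be mapped back to text, assuming an object in the style of llama_cpp.Llama whose detokenize() returns bytes (this is not code from the commit):

# Hedged sketch, not part of this commit: map special-token IDs back to text.
from collections.abc import Callable


def special_tokens_as_text(
    token_ids: list[int],
    detokenize: Callable[[list[int]], bytes],  # e.g. llama_cpp.Llama.detokenize
) -> list[str]:
    # Detokenize each ID individually so BOS/EOS/NL come back as separate strings.
    return [detokenize([tid]).decode("utf-8", errors="replace") for tid in token_ids]


# Usage, assuming `llm` is a loaded llama_cpp.Llama instance:
#   bos, eos, nl = llm.token_bos(), llm.token_eos(), llm.token_nl()
#   special_tokens_as_text([bos, eos, nl], llm.detokenize)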
@@ -183,14 +183,18 @@ class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
         logger.debug("Got default_system_prompt='%s'", default_system_prompt)
         self.default_system_prompt = default_system_prompt

-    def _add_missing_system_prompt(self, messages: Sequence[ChatMessage]) -> list[ChatMessage]:
+    def _add_missing_system_prompt(
+        self, messages: Sequence[ChatMessage]
+    ) -> Sequence[ChatMessage]:
         if messages[0].role != MessageRole.SYSTEM:
             logger.debug(
                 "Adding system_promt='%s' to the given messages as there are none given in the session",
                 self.default_system_prompt,
             )
             messages = [
-                ChatMessage(content=self.default_system_prompt, role=MessageRole.SYSTEM),
+                ChatMessage(
+                    content=self.default_system_prompt, role=MessageRole.SYSTEM
+                ),
                 *messages,
             ]
         return messages
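Note: for readers following the signature change above (_add_missing_system_prompt now accepts and returns a Sequence[ChatMessage]), here is a standalone sketch of the same prepend-if-missing pattern. Only ChatMessage and MessageRole come from llama_index, exactly as imported in the diff; the function name and default prompt are illustrative:

# Standalone sketch of the prepend-a-default-system-prompt pattern.
from collections.abc import Sequence

from llama_index.llms import ChatMessage, MessageRole

DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."


def add_missing_system_prompt(messages: Sequence[ChatMessage]) -> Sequence[ChatMessage]:
    """Prepend a default system message when the conversation does not start with one."""
    if messages and messages[0].role == MessageRole.SYSTEM:
        return messages
    return [
        ChatMessage(content=DEFAULT_SYSTEM_PROMPT, role=MessageRole.SYSTEM),
        *messages,
    ]


msgs = add_missing_system_prompt([ChatMessage(content="Hello!", role=MessageRole.USER)])
assert msgs[0].role == MessageRole.SYSTEM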
@@ -256,7 +260,11 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
     FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
     """

-    def __init__(self, default_system_prompt: str | None = None, add_generation_prompt: bool = True) -> None:
+    def __init__(
+        self,
+        default_system_prompt: str | None = None,
+        add_generation_prompt: bool = True,
+    ) -> None:
         # We have to define a default system prompt here as the LLM will not
         # use the default llama_utils functions.
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
@@ -272,7 +280,7 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
         messages = [ChatMessage(content=completion, role=MessageRole.USER)]
         return self._format_messages_to_prompt(messages)

-    def _format_messages_to_prompt(self, messages: list[ChatMessage]) -> str:
+    def _format_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         # TODO add BOS and EOS TOKEN !!!!! (c.f. jinja template)
         """Format message to prompt with `<|ROLE|>: MSG` style."""
         assert messages[0].role == MessageRole.SYSTEM
@@ -291,21 +299,23 @@ class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):


 class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, prompt_style: str, default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self, prompt_style: str, default_system_prompt: str | None = None
+    ) -> None:
         """Wrapper for llama_cpp_python defined prompt format.

         :param prompt_style:
         :param default_system_prompt: Used if no system prompt is given in the messages.
         """
         assert prompt_style.startswith("llama_cpp.")
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
         super().__init__(default_system_prompt)

-        self.prompt_style = prompt_style
+        self.prompt_style = prompt_style[len("llama_cpp.") :]
         if self.prompt_style is None:
             return

         self._llama_cpp_formatter = _get_llama_cpp_chat_format(self.prompt_style)
         self.messages_to_prompt = self._messages_to_prompt
         self.completion_to_prompt = self._completion_to_prompt

     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         messages = self._add_missing_system_prompt(messages)
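Note: the constructor now receives the full "llama_cpp.alpaca"-style string and strips the prefix itself, instead of relying on get_prompt_style to do it. A standalone sketch of that dispatch-by-prefix pattern (the registry and formatters below are invented; only the prefix stripping mirrors the diff):

# Invented registry and formatters; only the prefix-stripping pattern mirrors the diff.
from collections.abc import Callable

PREFIX = "llama_cpp."
FORMATTERS: dict[str, Callable[[str], str]] = {"alpaca": str.upper, "zephyr": str.lower}


def resolve_formatter(prompt_style: str) -> Callable[[str], str]:
    assert prompt_style.startswith(PREFIX), f"expected a '{PREFIX}*' style"
    name = prompt_style[len(PREFIX) :]  # "llama_cpp.alpaca" -> "alpaca"
    try:
        return FORMATTERS[name]
    except KeyError as err:
        raise ValueError(f"Unknown prompt style '{name}'") from err


assert resolve_formatter("llama_cpp.alpaca")("hi") == "HI"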
@@ -314,19 +324,28 @@ class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
         ).prompt

     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([ChatMessage(content=completion, role=MessageRole.USER)])
+        messages = self._add_missing_system_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
         return self._llama_cpp_formatter(
             messages=llama_index_to_llama_cpp_messages(messages)
         ).prompt


 class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, template_name: str, add_generation_prompt: bool = True,
-                 default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self,
+        template_name: str,
+        template_dir: str | None = None,
+        add_generation_prompt: bool = True,
+        default_system_prompt: str | None = None,
+    ) -> None:
         """Prompt format using a Jinja template.

         :param template_name: the filename of the template to use, must be in
             the `./template/` directory.
+        :param template_dir: the directory where the template is located.
+            Defaults to `./template/`.
         :param default_system_prompt: Used if no system prompt is
             given in the messages.
         """
@@ -335,12 +354,18 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):

         self._add_generation_prompt = add_generation_prompt

-        def raise_exception(message):
+        def raise_exception(message: str) -> None:
             raise TemplateError(message)

-        self._jinja_fs_loader = FileSystemLoader(searchpath=THIS_DIRECTORY_RELATIVE / "template")
-        self._jinja_env = ImmutableSandboxedEnvironment(loader=self._jinja_fs_loader, trim_blocks=True,
-                                                        lstrip_blocks=True)
+        if template_dir is None:
+            self.template_dir = THIS_DIRECTORY_RELATIVE / "template"
+        else:
+            self.template_dir = Path(template_dir)
+
+        self._jinja_fs_loader = FileSystemLoader(searchpath=self.template_dir)
+        self._jinja_env = ImmutableSandboxedEnvironment(
+            loader=self._jinja_fs_loader, trim_blocks=True, lstrip_blocks=True
+        )
         self._jinja_env.globals["raise_exception"] = raise_exception

         self.template = self._jinja_env.get_template(template_name)
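Note: the constructor now resolves template_dir (defaulting to the package's ./template/ directory) before building the Jinja environment. A self-contained sketch of the same FileSystemLoader plus ImmutableSandboxedEnvironment setup, using a throwaway directory and a made-up template, relying only on the jinja2 calls visible in the diff:

# Self-contained sketch of the Jinja setup above; template content and directory are made up.
import tempfile
from pathlib import Path

from jinja2 import FileSystemLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment

with tempfile.TemporaryDirectory() as tmp_dir:
    template_path = Path(tmp_dir) / "chat.jinja"
    template_path.write_text(
        "{% for m in messages %}<|{{ m.role }}|>: {{ m.content }}\n{% endfor %}"
    )

    env = ImmutableSandboxedEnvironment(
        loader=FileSystemLoader(searchpath=tmp_dir),
        trim_blocks=True,
        lstrip_blocks=True,
    )
    template = env.get_template("chat.jinja")
    prompt = template.render(
        messages=[{"role": "system", "content": "You are an AI assistant."}]
    )
    assert prompt == "<|system|>: You are an AI assistant.\n"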
@@ -368,9 +393,11 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
         )

     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([
-            ChatMessage(content=completion, role=MessageRole.USER),
-        ])
+        messages = self._add_missing_system_prompt(
+            [
+                ChatMessage(content=completion, role=MessageRole.USER),
+            ]
+        )
         return self._messages_to_prompt(messages)
@@ -380,17 +407,16 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
 # Pass all the arguments at once
 def get_prompt_style(
     prompt_style: str | None,
-    **kwargs: Any,
+    **kwargs: Any,
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.

     :param prompt_style: The prompt style to use.
     :return: The prompt style to use.
     """
-    if prompt_style is None or prompt_style == "default":
+    if prompt_style is None:
         return DefaultPromptStyle(**kwargs)
     if prompt_style.startswith("llama_cpp."):
-        prompt_style = prompt_style[len("llama_cpp."):]
         return LlamaCppPromptStyle(prompt_style, **kwargs)
     elif prompt_style == "llama2":
         return LlamaIndexPromptStyle(**kwargs)
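Note: based on the signature in this hunk and the parametrized test below, get_prompt_style dispatches on the style string and forwards keyword arguments to the selected style's constructor. A hedged usage sketch (the import path matches the diff; argument values are invented):

# Hedged usage sketch; argument values are invented.
from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

# Styles backed by llama_cpp_python's built-in chat formatters keep their
# "llama_cpp." prefix; per this commit, the style class strips it internally.
style = get_prompt_style(
    "llama_cpp.alpaca",
    default_system_prompt="You are an AI assistant.",
)
prompt = style.completion_to_prompt("Hello, how are you doing?")
print(prompt)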
@@ -1,9 +1,14 @@
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+
 import pytest
+from llama_index.llms import ChatMessage, MessageRole

 from private_gpt.components.llm.prompt.prompt_helper import (
     DefaultPromptStyle,
+    LlamaCppPromptStyle,
     LlamaIndexPromptStyle,
     TemplatePromptStyle,
     VigognePromptStyle,
     get_prompt_style,
 )
@@ -12,13 +17,44 @@ from private_gpt.components.llm.prompt.prompt_helper import (
 @pytest.mark.parametrize(
     ("prompt_style", "expected_prompt_style"),
     [
         ("default", DefaultPromptStyle),
         (None, DefaultPromptStyle),
         ("llama2", LlamaIndexPromptStyle),
         ("tag", VigognePromptStyle),
+        ("vigogne", VigognePromptStyle),
+        ("llama_cpp.alpaca", LlamaCppPromptStyle),
+        ("llama_cpp.zephyr", LlamaCppPromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
-    assert get_prompt_style(prompt_style) == expected_prompt_style
+    assert type(get_prompt_style(prompt_style)) == expected_prompt_style


+def test_get_prompt_style_template_success():
+    jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
+    with NamedTemporaryFile("w") as tmp_file:
+        path = Path(tmp_file.name)
+        tmp_file.write(jinja_template)
+        tmp_file.flush()
+        tmp_file.seek(0)
+        prompt_style = get_prompt_style(
+            "template", template_name=path.name, template_dir=path.parent
+        )
+        assert type(prompt_style) == TemplatePromptStyle
+        prompt = prompt_style.messages_to_prompt(
+            [
+                ChatMessage(
+                    content="You are an AI assistant.", role=MessageRole.SYSTEM
+                ),
+                ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+            ]
+        )
+
+        expected_prompt = (
+            "<|system|>: You are an AI assistant.\n"
+            "<|user|>: Hello, how are you doing?\n"
+            "<|assistant|>: "
+        )
+
+        assert prompt == expected_prompt
+
+
 def test_get_prompt_style_failure():