mirror of https://github.com/imartinez/privateGPT.git
synced 2025-08-08 02:53:41 +00:00

Fix typing, linting and add tests

This commit is contained in:
parent 76faffb269
commit 5bc5054000
@@ -22,12 +22,14 @@ class LLMComponent:
             case "local":
                 from llama_index.llms import LlamaCPP

-                from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style
+                from private_gpt.components.llm.prompt.prompt_helper import (
+                    get_prompt_style,
+                )

                 prompt_style = get_prompt_style(
                     prompt_style=settings.local.prompt_style,
                     template_name=settings.local.template_name,
-                    default_system_prompt=settings.local.default_system_prompt
+                    default_system_prompt=settings.local.default_system_prompt,
                 )

                 self.llm = LlamaCPP(
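For orientation, a hedged sketch of how the reworked call can be exercised outside the component. Only the keyword names come from the hunk above; the concrete values (the "template" style and the "default.jinja" file name) are invented for illustration, and in LLMComponent they come from settings.local.*.

from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

# Illustration only: stand-in values for settings.local.* fields.
prompt_style = get_prompt_style(
    prompt_style="template",                                # settings.local.prompt_style
    template_name="default.jinja",                          # settings.local.template_name (hypothetical file)
    default_system_prompt="You are a helpful assistant.",   # settings.local.default_system_prompt
)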
@@ -1,5 +1,4 @@
-"""
-Helper to get your llama_index messages correctly serialized into a prompt.
+"""Helper to get your llama_index messages correctly serialized into a prompt.

 This set of classes and functions is used to format a series of
 llama_index ChatMessage into a prompt (a unique string) that will be passed
@@ -29,14 +28,13 @@ import abc
 import logging
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, cast
+from typing import Any

 from jinja2 import FileSystemLoader
 from jinja2.exceptions import TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from llama_cpp import llama_types, Llama
-from llama_cpp import llama_chat_format
-from llama_index.llms import ChatMessage, MessageRole, LlamaCPP
+from llama_cpp import llama_chat_format, llama_types
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.llms.llama_utils import (
     DEFAULT_SYSTEM_PROMPT,
     completion_to_prompt,
@@ -50,7 +48,7 @@ logger = logging.getLogger(__name__)
 THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)


-_LLAMA_CPP_PYTHON_CHAT_FORMAT = {
+_LLAMA_CPP_PYTHON_CHAT_FORMAT: dict[str, llama_chat_format.ChatFormatter] = {
     "llama-2": llama_chat_format.format_llama2,
     "alpaca": llama_chat_format.format_alpaca,
     "vicuna": llama_chat_format.format,
@@ -80,6 +78,7 @@ def llama_index_to_llama_cpp_messages(
         list of llama_cpp ChatCompletionRequestMessage.
     """
     llama_cpp_messages: list[llama_types.ChatCompletionRequestMessage] = []
+    l_msg: llama_types.ChatCompletionRequestMessage
     for msg in messages:
         if msg.role == MessageRole.SYSTEM:
             l_msg = llama_types.ChatCompletionRequestSystemMessage(
@@ -112,10 +111,11 @@ def llama_index_to_llama_cpp_messages(


 def _get_llama_cpp_chat_format(name: str) -> llama_chat_format.ChatFormatter:
+    logger.debug("Getting llama_cpp_python prompt_format='%s'", name)
     try:
         return _LLAMA_CPP_PYTHON_CHAT_FORMAT[name]
-    except KeyError:
-        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'")
+    except KeyError as err:
+        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'") from err


 class AbstractPromptStyle(abc.ABC):
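A small usage sketch of the helper above (it is module-private, so this is illustration only): a known name resolves to a llama_cpp formatter, and an unknown one now raises a ValueError chained to the original KeyError thanks to raise ... from err.

formatter = _get_llama_cpp_chat_format("alpaca")    # resolves to llama_chat_format.format_alpaca

try:
    _get_llama_cpp_chat_format("unknown-style")
except ValueError as exc:
    assert isinstance(exc.__cause__, KeyError)      # exception chaining preserved by `from err`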
@@ -161,18 +161,18 @@ class AbstractPromptStyle(abc.ABC):
         logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
         return prompt

-    def improve_prompt_format(self, llm: LlamaCPP) -> None:
-        """Improve the prompt format of the given LLM.
-
-        Use the given metadata in the LLM to improve the prompt format.
-        """
-        # FIXME: below we are getting IDs (1,2,13) from llama.cpp, and not actual strings
-        llama_cpp_llm = cast(Llama, llm._model)
-        self.bos_token = llama_cpp_llm.token_bos()
-        self.eos_token = llama_cpp_llm.token_eos()
-        self.nl_token = llama_cpp_llm.token_nl()
-        print([self.bos_token, self.eos_token, self.nl_token])
-        # (1,2,13) are the IDs of the tokens
+    # def improve_prompt_format(self, llm: LlamaCPP) -> None:
+    #     """Improve the prompt format of the given LLM.
+    #
+    #     Use the given metadata in the LLM to improve the prompt format.
+    #     """
+    #     # FIXME: we are getting IDs (1,2,13) from llama.cpp, and not actual strings
+    #     llama_cpp_llm = cast(Llama, llm._model)
+    #     self.bos_token = llama_cpp_llm.token_bos()
+    #     self.eos_token = llama_cpp_llm.token_eos()
+    #     self.nl_token = llama_cpp_llm.token_nl()
+    #     print([self.bos_token, self.eos_token, self.nl_token])
+    #     # (1,2,13) are the IDs of the tokens


 class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
@@ -183,14 +183,18 @@ class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
         logger.debug("Got default_system_prompt='%s'", default_system_prompt)
         self.default_system_prompt = default_system_prompt

-    def _add_missing_system_prompt(self, messages: Sequence[ChatMessage]) -> list[ChatMessage]:
+    def _add_missing_system_prompt(
+        self, messages: Sequence[ChatMessage]
+    ) -> Sequence[ChatMessage]:
         if messages[0].role != MessageRole.SYSTEM:
             logger.debug(
                 "Adding system_promt='%s' to the given messages as there are none given in the session",
                 self.default_system_prompt,
             )
             messages = [
-                ChatMessage(content=self.default_system_prompt, role=MessageRole.SYSTEM),
+                ChatMessage(
+                    content=self.default_system_prompt, role=MessageRole.SYSTEM
+                ),
                 *messages,
             ]
         return messages
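A hedged behavioural sketch of _add_missing_system_prompt for some concrete prompt style instance (the variable style below is hypothetical, not from the diff): when the first message is not a system message, the configured default system prompt is prepended.

from llama_index.llms import ChatMessage, MessageRole

user_only = [ChatMessage(content="Hi", role=MessageRole.USER)]
completed = style._add_missing_system_prompt(user_only)    # `style` is any concrete subclass instance

assert completed[0].role == MessageRole.SYSTEM              # default prompt was prepended
assert completed[0].content == style.default_system_prompt
assert completed[1].content == "Hi"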
@@ -256,7 +260,11 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
     FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
     """

-    def __init__(self, default_system_prompt: str | None = None, add_generation_prompt: bool = True) -> None:
+    def __init__(
+        self,
+        default_system_prompt: str | None = None,
+        add_generation_prompt: bool = True,
+    ) -> None:
         # We have to define a default system prompt here as the LLM will not
         # use the default llama_utils functions.
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
@@ -272,7 +280,7 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
         messages = [ChatMessage(content=completion, role=MessageRole.USER)]
         return self._format_messages_to_prompt(messages)

-    def _format_messages_to_prompt(self, messages: list[ChatMessage]) -> str:
+    def _format_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         # TODO add BOS and EOS TOKEN !!!!! (c.f. jinja template)
         """Format message to prompt with `<|ROLE|>: MSG` style."""
         assert messages[0].role == MessageRole.SYSTEM
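The docstring above names the `<|ROLE|>: MSG` rendering. As a hedged, comment-only sketch (the trailing generation marker and the exact newlines are assumptions based on the tests added further down, not on this hunk), the style is expected to turn a short chat into:

# messages (illustrative):
#   ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM)
#   ChatMessage(content="Hello!", role=MessageRole.USER)
#
# rendered prompt (assumed shape):
#   <|system|>: You are an AI assistant.
#   <|user|>: Hello!
#   <|assistant|>: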
@@ -291,21 +299,23 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):


 class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, prompt_style: str, default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self, prompt_style: str, default_system_prompt: str | None = None
+    ) -> None:
         """Wrapper for llama_cpp_python defined prompt format.

         :param prompt_style:
         :param default_system_prompt: Used if no system prompt is given in the messages.
         """
+        assert prompt_style.startswith("llama_cpp.")
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
         super().__init__(default_system_prompt)

-        self.prompt_style = prompt_style
+        self.prompt_style = prompt_style[len("llama_cpp.") :]
         if self.prompt_style is None:
             return

         self._llama_cpp_formatter = _get_llama_cpp_chat_format(self.prompt_style)
-        self.messages_to_prompt = self._messages_to_prompt
-        self.completion_to_prompt = self._completion_to_prompt

     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         messages = self._add_missing_system_prompt(messages)
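A hedged sketch of the new contract for LlamaCppPromptStyle: the constructor now asserts the llama_cpp. prefix and strips it before resolving the formatter. The public completion_to_prompt call below assumes the wrapper defined on AbstractPromptStyle, which is only visible as context in this diff.

style = LlamaCppPromptStyle("llama_cpp.alpaca", default_system_prompt="You are helpful.")
assert style.prompt_style == "alpaca"              # "llama_cpp." prefix stripped

prompt = style.completion_to_prompt("Hello!")       # rendered by llama_cpp's alpaca formatter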
@@ -314,19 +324,28 @@ class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
         ).prompt

     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([ChatMessage(content=completion, role=MessageRole.USER)])
+        messages = self._add_missing_system_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
         return self._llama_cpp_formatter(
             messages=llama_index_to_llama_cpp_messages(messages)
         ).prompt


 class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, template_name: str, add_generation_prompt: bool = True,
-                 default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self,
+        template_name: str,
+        template_dir: str | None = None,
+        add_generation_prompt: bool = True,
+        default_system_prompt: str | None = None,
+    ) -> None:
         """Prompt format using a Jinja template.

         :param template_name: the filename of the template to use, must be in
             the `./template/` directory.
+        :param template_dir: the directory where the template is located.
+            Defaults to `./template/`.
         :param default_system_prompt: Used if no system prompt is
             given in the messages.
         """
@@ -335,12 +354,18 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):

         self._add_generation_prompt = add_generation_prompt

-        def raise_exception(message):
+        def raise_exception(message: str) -> None:
             raise TemplateError(message)

-        self._jinja_fs_loader = FileSystemLoader(searchpath=THIS_DIRECTORY_RELATIVE / "template")
-        self._jinja_env = ImmutableSandboxedEnvironment(loader=self._jinja_fs_loader, trim_blocks=True,
-                                                        lstrip_blocks=True)
+        if template_dir is None:
+            self.template_dir = THIS_DIRECTORY_RELATIVE / "template"
+        else:
+            self.template_dir = Path(template_dir)
+
+        self._jinja_fs_loader = FileSystemLoader(searchpath=self.template_dir)
+        self._jinja_env = ImmutableSandboxedEnvironment(
+            loader=self._jinja_fs_loader, trim_blocks=True, lstrip_blocks=True
+        )
         self._jinja_env.globals["raise_exception"] = raise_exception

         self.template = self._jinja_env.get_template(template_name)
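A hedged construction sketch for the new template_dir parameter; the directory and file name below are invented, and by default the style keeps loading from the `./template/` directory next to prompt_helper.py.

style = TemplatePromptStyle(
    template_name="chatml.jinja",             # hypothetical template file
    template_dir="/tmp/my_templates",         # overrides the default ./template/ directory
    add_generation_prompt=True,
    default_system_prompt="You are a helpful assistant.",
)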
@@ -368,9 +393,11 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
         )

     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([
-            ChatMessage(content=completion, role=MessageRole.USER),
-        ])
+        messages = self._add_missing_system_prompt(
+            [
+                ChatMessage(content=completion, role=MessageRole.USER),
+            ]
+        )
         return self._messages_to_prompt(messages)
@@ -387,10 +414,9 @@ def get_prompt_style(
     :param prompt_style: The prompt style to use.
     :return: The prompt style to use.
     """
-    if prompt_style is None or prompt_style == "default":
+    if prompt_style is None:
         return DefaultPromptStyle(**kwargs)
     if prompt_style.startswith("llama_cpp."):
-        prompt_style = prompt_style[len("llama_cpp."):]
         return LlamaCppPromptStyle(prompt_style, **kwargs)
     elif prompt_style == "llama2":
         return LlamaIndexPromptStyle(**kwargs)
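Putting the dispatcher together, a hedged sketch of the mapping after this hunk; the "vigogne" name comes from the test parametrization further down, not from this hunk itself.

assert isinstance(get_prompt_style(None), DefaultPromptStyle)
assert isinstance(get_prompt_style("llama2"), LlamaIndexPromptStyle)
assert isinstance(get_prompt_style("vigogne"), VigognePromptStyle)
assert isinstance(get_prompt_style("llama_cpp.alpaca"), LlamaCppPromptStyle)   # prefix now stripped inside the class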
@@ -1,9 +1,14 @@
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+
 import pytest
 from llama_index.llms import ChatMessage, MessageRole

 from private_gpt.components.llm.prompt.prompt_helper import (
     DefaultPromptStyle,
+    LlamaCppPromptStyle,
     LlamaIndexPromptStyle,
+    TemplatePromptStyle,
     VigognePromptStyle,
     get_prompt_style,
 )
@@ -12,13 +17,44 @@ from private_gpt.components.llm.prompt.prompt_helper import (
 @pytest.mark.parametrize(
     ("prompt_style", "expected_prompt_style"),
     [
-        ("default", DefaultPromptStyle),
+        (None, DefaultPromptStyle),
         ("llama2", LlamaIndexPromptStyle),
-        ("tag", VigognePromptStyle),
+        ("vigogne", VigognePromptStyle),
+        ("llama_cpp.alpaca", LlamaCppPromptStyle),
+        ("llama_cpp.zephyr", LlamaCppPromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
-    assert get_prompt_style(prompt_style) == expected_prompt_style
+    assert type(get_prompt_style(prompt_style)) == expected_prompt_style
+
+
+def test_get_prompt_style_template_success():
+    jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
+    with NamedTemporaryFile("w") as tmp_file:
+        path = Path(tmp_file.name)
+        tmp_file.write(jinja_template)
+        tmp_file.flush()
+        tmp_file.seek(0)
+        prompt_style = get_prompt_style(
+            "template", template_name=path.name, template_dir=path.parent
+        )
+        assert type(prompt_style) == TemplatePromptStyle
+        prompt = prompt_style.messages_to_prompt(
+            [
+                ChatMessage(
+                    content="You are an AI assistant.", role=MessageRole.SYSTEM
+                ),
+                ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+            ]
+        )
+
+        expected_prompt = (
+            "<|system|>: You are an AI assistant.\n"
+            "<|user|>: Hello, how are you doing?\n"
+            "<|assistant|>: "
+        )
+
+        assert prompt == expected_prompt


 def test_get_prompt_style_failure():