Fix typing, linting and add tests

Louis 2023-12-03 16:31:20 +01:00
parent 76faffb269
commit 5bc5054000
3 changed files with 111 additions and 47 deletions

View File

@@ -22,12 +22,14 @@ class LLMComponent:
             case "local":
                 from llama_index.llms import LlamaCPP
-                from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style
+                from private_gpt.components.llm.prompt.prompt_helper import (
+                    get_prompt_style,
+                )

                 prompt_style = get_prompt_style(
                     prompt_style=settings.local.prompt_style,
                     template_name=settings.local.template_name,
-                    default_system_prompt=settings.local.default_system_prompt
+                    default_system_prompt=settings.local.default_system_prompt,
                 )
                 self.llm = LlamaCPP(
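
Editor's note: a minimal sketch of how the reworked factory can be exercised on its own. The style name and system prompt below are illustrative values, and this assumes the "llama2" style accepts default_system_prompt like the other styles in this file:

    from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style

    # "llama2" dispatches to LlamaIndexPromptStyle (see get_prompt_style below).
    style = get_prompt_style(
        "llama2",
        default_system_prompt="You are a helpful assistant.",  # illustrative value
    )
    print(style.completion_to_prompt("Hello!"))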

View File

@@ -1,5 +1,4 @@
-"""
-Helper to get your llama_index messages correctly serialized into a prompt.
+"""Helper to get your llama_index messages correctly serialized into a prompt.

 This set of classes and functions is used to format a series of
 llama_index ChatMessage into a prompt (a unique string) that will be passed
@@ -29,14 +28,13 @@ import abc
 import logging
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, cast
+from typing import Any

 from jinja2 import FileSystemLoader
 from jinja2.exceptions import TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from llama_cpp import llama_types, Llama
-from llama_cpp import llama_chat_format
-from llama_index.llms import ChatMessage, MessageRole, LlamaCPP
+from llama_cpp import llama_chat_format, llama_types
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.llms.llama_utils import (
     DEFAULT_SYSTEM_PROMPT,
     completion_to_prompt,
@@ -50,7 +48,7 @@ logger = logging.getLogger(__name__)
 THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)

-_LLAMA_CPP_PYTHON_CHAT_FORMAT = {
+_LLAMA_CPP_PYTHON_CHAT_FORMAT: dict[str, llama_chat_format.ChatFormatter] = {
     "llama-2": llama_chat_format.format_llama2,
     "alpaca": llama_chat_format.format_alpaca,
     "vicuna": llama_chat_format.format,
@@ -80,6 +78,7 @@ def llama_index_to_llama_cpp_messages(
         list of llama_cpp ChatCompletionRequestMessage.
     """
     llama_cpp_messages: list[llama_types.ChatCompletionRequestMessage] = []
+    l_msg: llama_types.ChatCompletionRequestMessage
     for msg in messages:
         if msg.role == MessageRole.SYSTEM:
             l_msg = llama_types.ChatCompletionRequestSystemMessage(
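
Editor's note: the new l_msg annotation is the usual mypy pattern for a variable assigned a different TypedDict subtype in each branch: declare the union-typed name once, then let every branch assign into it. A self-contained sketch, assuming llama_cpp also exposes ChatCompletionRequestUserMessage (role and text are illustrative):

    from llama_cpp import llama_types

    role, text = "system", "Hello!"

    l_msg: llama_types.ChatCompletionRequestMessage
    if role == "system":
        # Each branch assigns a narrower TypedDict; the annotation above
        # lets mypy unify them under ChatCompletionRequestMessage.
        l_msg = llama_types.ChatCompletionRequestSystemMessage(role="system", content=text)
    else:
        l_msg = llama_types.ChatCompletionRequestUserMessage(role="user", content=text)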
@@ -112,10 +111,11 @@
 def _get_llama_cpp_chat_format(name: str) -> llama_chat_format.ChatFormatter:
     logger.debug("Getting llama_cpp_python prompt_format='%s'", name)
     try:
         return _LLAMA_CPP_PYTHON_CHAT_FORMAT[name]
-    except KeyError:
-        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'")
+    except KeyError as err:
+        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'") from err


 class AbstractPromptStyle(abc.ABC):
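
Editor's note: the `from err` change above is the standard fix for flake8-bugbear rule B904. Re-raising inside an except block without `from` produces the misleading "During handling of the above exception, another exception occurred" traceback. A standalone illustration (the dict is a stand-in for _LLAMA_CPP_PYTHON_CHAT_FORMAT):

    styles = {"alpaca": object()}  # stand-in registry

    try:
        styles["zephyr"]
    except KeyError as err:
        # __cause__ is now the KeyError, so the traceback reads
        # "The above exception was the direct cause of ..."
        raise ValueError("Unknown llama_cpp_python prompt style 'zephyr'") from err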
@@ -161,18 +161,18 @@
         logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
         return prompt

-    def improve_prompt_format(self, llm: LlamaCPP) -> None:
-        """Improve the prompt format of the given LLM.
-
-        Use the given metadata in the LLM to improve the prompt format.
-        """
-        # FIXME: below we are getting IDs (1,2,13) from llama.cpp, and not actual strings
-        llama_cpp_llm = cast(Llama, llm._model)
-        self.bos_token = llama_cpp_llm.token_bos()
-        self.eos_token = llama_cpp_llm.token_eos()
-        self.nl_token = llama_cpp_llm.token_nl()
-        print([self.bos_token, self.eos_token, self.nl_token])
-        # (1,2,13) are the IDs of the tokens
+    # def improve_prompt_format(self, llm: LlamaCPP) -> None:
+    #     """Improve the prompt format of the given LLM.
+    #
+    #     Use the given metadata in the LLM to improve the prompt format.
+    #     """
+    #     # FIXME: we are getting IDs (1,2,13) from llama.cpp, and not actual strings
+    #     llama_cpp_llm = cast(Llama, llm._model)
+    #     self.bos_token = llama_cpp_llm.token_bos()
+    #     self.eos_token = llama_cpp_llm.token_eos()
+    #     self.nl_token = llama_cpp_llm.token_nl()
+    #     print([self.bos_token, self.eos_token, self.nl_token])
+    #     # (1,2,13) are the IDs of the tokens


 class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
@@ -183,14 +183,18 @@
         logger.debug("Got default_system_prompt='%s'", default_system_prompt)
         self.default_system_prompt = default_system_prompt

-    def _add_missing_system_prompt(self, messages: Sequence[ChatMessage]) -> list[ChatMessage]:
+    def _add_missing_system_prompt(
+        self, messages: Sequence[ChatMessage]
+    ) -> Sequence[ChatMessage]:
         if messages[0].role != MessageRole.SYSTEM:
             logger.debug(
                 "Adding system_prompt='%s' to the given messages as none was given in the session",
                 self.default_system_prompt,
             )
             messages = [
-                ChatMessage(content=self.default_system_prompt, role=MessageRole.SYSTEM),
+                ChatMessage(
+                    content=self.default_system_prompt, role=MessageRole.SYSTEM
+                ),
                 *messages,
             ]
         return messages
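
Editor's note: the widened Sequence[ChatMessage] return type matches what the method really produces; behaviorally it simply prepends a system message when the first one is missing. A reduced sketch of that contract (the default prompt value is illustrative):

    from collections.abc import Sequence

    from llama_index.llms import ChatMessage, MessageRole

    DEFAULT_PROMPT = "You are a helpful assistant."  # stand-in default_system_prompt


    def add_missing_system_prompt(messages: Sequence[ChatMessage]) -> Sequence[ChatMessage]:
        # Mirrors the method above: prepend only when no SYSTEM message leads.
        if messages[0].role != MessageRole.SYSTEM:
            return [ChatMessage(content=DEFAULT_PROMPT, role=MessageRole.SYSTEM), *messages]
        return messages


    out = add_missing_system_prompt([ChatMessage(content="Hi!", role=MessageRole.USER)])
    assert [m.role for m in out] == [MessageRole.SYSTEM, MessageRole.USER]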
@@ -256,7 +260,11 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
     FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
     """

-    def __init__(self, default_system_prompt: str | None = None, add_generation_prompt: bool = True) -> None:
+    def __init__(
+        self,
+        default_system_prompt: str | None = None,
+        add_generation_prompt: bool = True,
+    ) -> None:
         # We have to define a default system prompt here as the LLM will not
         # use the default llama_utils functions.
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
@@ -272,7 +280,7 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
         messages = [ChatMessage(content=completion, role=MessageRole.USER)]
         return self._format_messages_to_prompt(messages)

-    def _format_messages_to_prompt(self, messages: list[ChatMessage]) -> str:
+    def _format_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         # TODO add BOS and EOS TOKEN !!!!! (c.f. jinja template)
         """Format message to prompt with `<|ROLE|>: MSG` style."""
         assert messages[0].role == MessageRole.SYSTEM
@@ -291,21 +299,23 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):

 class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, prompt_style: str, default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self, prompt_style: str, default_system_prompt: str | None = None
+    ) -> None:
         """Wrapper for llama_cpp_python defined prompt format.

         :param prompt_style: name of the llama_cpp_python chat format to use,
             prefixed with `llama_cpp.` (e.g. `llama_cpp.alpaca`).
         :param default_system_prompt: Used if no system prompt is given in the messages.
         """
+        assert prompt_style.startswith("llama_cpp.")
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
         super().__init__(default_system_prompt)
-        self.prompt_style = prompt_style
+        self.prompt_style = prompt_style[len("llama_cpp.") :]

         if self.prompt_style is None:
             return
         self._llama_cpp_formatter = _get_llama_cpp_chat_format(self.prompt_style)

         self.messages_to_prompt = self._messages_to_prompt
         self.completion_to_prompt = self._completion_to_prompt

     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         messages = self._add_missing_system_prompt(messages)
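
Editor's note: with the prefix handling moved into the class, get_prompt_style can now pass the full name through unchanged. The slice below is the same one used in __init__ (the space in [len("llama_cpp.") :] is just black's slice formatting):

    name = "llama_cpp.zephyr"
    assert name.startswith("llama_cpp.")
    fmt = name[len("llama_cpp.") :]
    assert fmt == "zephyr"  # the key looked up in _LLAMA_CPP_PYTHON_CHAT_FORMAT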
@@ -314,19 +324,28 @@ class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
         ).prompt

     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([ChatMessage(content=completion, role=MessageRole.USER)])
+        messages = self._add_missing_system_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
         return self._llama_cpp_formatter(
             messages=llama_index_to_llama_cpp_messages(messages)
         ).prompt


 class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, template_name: str, add_generation_prompt: bool = True,
-                 default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self,
+        template_name: str,
+        template_dir: str | None = None,
+        add_generation_prompt: bool = True,
+        default_system_prompt: str | None = None,
+    ) -> None:
         """Prompt format using a Jinja template.

         :param template_name: the filename of the template to use, must be in
             the template directory.
+        :param template_dir: the directory where the template is located.
+            Defaults to `./template/`.
         :param default_system_prompt: Used if no system prompt is
             given in the messages.
         """
@@ -335,12 +354,18 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
         self._add_generation_prompt = add_generation_prompt

-        def raise_exception(message):
+        def raise_exception(message: str) -> None:
             raise TemplateError(message)

-        self._jinja_fs_loader = FileSystemLoader(searchpath=THIS_DIRECTORY_RELATIVE / "template")
-        self._jinja_env = ImmutableSandboxedEnvironment(loader=self._jinja_fs_loader, trim_blocks=True,
-                                                        lstrip_blocks=True)
+        if template_dir is None:
+            self.template_dir = THIS_DIRECTORY_RELATIVE / "template"
+        else:
+            self.template_dir = Path(template_dir)
+        self._jinja_fs_loader = FileSystemLoader(searchpath=self.template_dir)
+        self._jinja_env = ImmutableSandboxedEnvironment(
+            loader=self._jinja_fs_loader, trim_blocks=True, lstrip_blocks=True
+        )
         self._jinja_env.globals["raise_exception"] = raise_exception
         self.template = self._jinja_env.get_template(template_name)
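
Editor's note: the Jinja wiring above is standard jinja2 API; as a standalone sketch, the same environment can be pointed at any template directory (directory and filename here are hypothetical):

    from jinja2 import FileSystemLoader
    from jinja2.exceptions import TemplateError
    from jinja2.sandbox import ImmutableSandboxedEnvironment


    def raise_exception(message: str) -> None:
        raise TemplateError(message)


    env = ImmutableSandboxedEnvironment(
        loader=FileSystemLoader(searchpath="/tmp/templates"),  # hypothetical dir
        trim_blocks=True,
        lstrip_blocks=True,
    )
    env.globals["raise_exception"] = raise_exception  # callable from templates
    template = env.get_template("chat.jinja")  # hypothetical filename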
@@ -368,9 +393,11 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
         )

     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([
-            ChatMessage(content=completion, role=MessageRole.USER),
-        ])
+        messages = self._add_missing_system_prompt(
+            [
+                ChatMessage(content=completion, role=MessageRole.USER),
+            ]
+        )
         return self._messages_to_prompt(messages)
@@ -380,17 +407,16 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
 # Pass all the arguments at once
 def get_prompt_style(
     prompt_style: str | None,
-        **kwargs: Any,
+    **kwargs: Any,
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.

     :param prompt_style: The prompt style to use.
     :return: The prompt style to use.
     """
-    if prompt_style is None or prompt_style == "default":
+    if prompt_style is None:
         return DefaultPromptStyle(**kwargs)
     if prompt_style.startswith("llama_cpp."):
-        prompt_style = prompt_style[len("llama_cpp."):]
         return LlamaCppPromptStyle(prompt_style, **kwargs)
     elif prompt_style == "llama2":
         return LlamaIndexPromptStyle(**kwargs)
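
Editor's note: after this change the dispatch reads: None gives DefaultPromptStyle, a llama_cpp.-prefixed name is handed to LlamaCppPromptStyle unstripped (it strips the prefix itself), and "llama2" gives LlamaIndexPromptStyle. A sketch:

    from private_gpt.components.llm.prompt.prompt_helper import (
        LlamaCppPromptStyle,
        get_prompt_style,
    )

    style = get_prompt_style("llama_cpp.alpaca")
    assert type(style) is LlamaCppPromptStyle
    assert style.prompt_style == "alpaca"  # prefix stripped inside __init__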

View File

@@ -1,9 +1,14 @@
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+
 import pytest
+from llama_index.llms import ChatMessage, MessageRole

 from private_gpt.components.llm.prompt.prompt_helper import (
     DefaultPromptStyle,
+    LlamaCppPromptStyle,
     LlamaIndexPromptStyle,
+    TemplatePromptStyle,
     VigognePromptStyle,
     get_prompt_style,
 )
@@ -12,13 +17,44 @@ from private_gpt.components.llm.prompt.prompt_helper import (
 @pytest.mark.parametrize(
     ("prompt_style", "expected_prompt_style"),
     [
-        ("default", DefaultPromptStyle),
         (None, DefaultPromptStyle),
         ("llama2", LlamaIndexPromptStyle),
-        ("tag", VigognePromptStyle),
+        ("vigogne", VigognePromptStyle),
+        ("llama_cpp.alpaca", LlamaCppPromptStyle),
+        ("llama_cpp.zephyr", LlamaCppPromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
-    assert get_prompt_style(prompt_style) == expected_prompt_style
+    assert type(get_prompt_style(prompt_style)) == expected_prompt_style
+
+
+def test_get_prompt_style_template_success():
+    jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
+    with NamedTemporaryFile("w") as tmp_file:
+        path = Path(tmp_file.name)
+        tmp_file.write(jinja_template)
+        tmp_file.flush()
+        tmp_file.seek(0)
+
+        prompt_style = get_prompt_style(
+            "template", template_name=path.name, template_dir=path.parent
+        )
+        assert type(prompt_style) == TemplatePromptStyle
+
+        prompt = prompt_style.messages_to_prompt(
+            [
+                ChatMessage(
+                    content="You are an AI assistant.", role=MessageRole.SYSTEM
+                ),
+                ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+            ]
+        )
+        expected_prompt = (
+            "<|system|>: You are an AI assistant.\n"
+            "<|user|>: Hello, how are you doing?\n"
+            "<|assistant|>: "
+        )
+        assert prompt == expected_prompt


 def test_get_prompt_style_failure():
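
Editor's note: as a cross-check, the template string used in the new test renders to the same expected prompt when fed to a sandboxed Jinja environment directly; a sketch of what TemplatePromptStyle is expected to do with it:

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    rendered = env.from_string(jinja_template).render(
        messages=[
            {"role": "system", "content": "You are an AI assistant."},
            {"role": "user", "content": "Hello, how are you doing?"},
        ]
    )
    assert rendered == (
        "<|system|>: You are an AI assistant.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )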