diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index e03ac266..8c2b3ee2 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -22,12 +22,14 @@ class LLMComponent:
             case "local":
                 from llama_index.llms import LlamaCPP
 
-                from private_gpt.components.llm.prompt.prompt_helper import get_prompt_style
+                from private_gpt.components.llm.prompt.prompt_helper import (
+                    get_prompt_style,
+                )
 
                 prompt_style = get_prompt_style(
                     prompt_style=settings.local.prompt_style,
                     template_name=settings.local.template_name,
-                    default_system_prompt=settings.local.default_system_prompt
+                    default_system_prompt=settings.local.default_system_prompt,
                 )
 
                 self.llm = LlamaCPP(
diff --git a/private_gpt/components/llm/prompt/prompt_helper.py b/private_gpt/components/llm/prompt/prompt_helper.py
index ace96c40..0e5ff2e9 100644
--- a/private_gpt/components/llm/prompt/prompt_helper.py
+++ b/private_gpt/components/llm/prompt/prompt_helper.py
@@ -1,5 +1,4 @@
-"""
-Helper to get your llama_index messages correctly serialized into a prompt.
+"""Helper to get your llama_index messages correctly serialized into a prompt.
 
 This set of classes and functions is used to format a series of
 llama_index ChatMessage into a prompt (a unique string) that will be passed
@@ -29,14 +28,13 @@ import abc
 import logging
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, cast
+from typing import Any
 
 from jinja2 import FileSystemLoader
 from jinja2.exceptions import TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
-from llama_cpp import llama_types, Llama
-from llama_cpp import llama_chat_format
-from llama_index.llms import ChatMessage, MessageRole, LlamaCPP
+from llama_cpp import llama_chat_format, llama_types
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.llms.llama_utils import (
     DEFAULT_SYSTEM_PROMPT,
     completion_to_prompt,
@@ -50,7 +48,7 @@ logger = logging.getLogger(__name__)
 
 THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
 
-_LLAMA_CPP_PYTHON_CHAT_FORMAT = {
+_LLAMA_CPP_PYTHON_CHAT_FORMAT: dict[str, llama_chat_format.ChatFormatter] = {
     "llama-2": llama_chat_format.format_llama2,
     "alpaca": llama_chat_format.format_alpaca,
     "vicuna": llama_chat_format.format,
@@ -80,6 +78,7 @@ def llama_index_to_llama_cpp_messages(
         list of llama_cpp ChatCompletionRequestMessage.
     """
     llama_cpp_messages: list[llama_types.ChatCompletionRequestMessage] = []
+    l_msg: llama_types.ChatCompletionRequestMessage
     for msg in messages:
         if msg.role == MessageRole.SYSTEM:
             l_msg = llama_types.ChatCompletionRequestSystemMessage(
@@ -112,10 +111,11 @@ def _get_llama_cpp_chat_format(name: str) -> llama_chat_format.ChatFormatter:
+    logger.debug("Getting llama_cpp_python prompt_format='%s'", name)
     try:
         return _LLAMA_CPP_PYTHON_CHAT_FORMAT[name]
-    except KeyError:
-        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'")
+    except KeyError as err:
+        raise ValueError(f"Unknown llama_cpp_python prompt style '{name}'") from err
 
 
 class AbstractPromptStyle(abc.ABC):
@@ -161,18 +161,18 @@ class AbstractPromptStyle(abc.ABC):
         logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
         return prompt
 
-    def improve_prompt_format(self, llm: LlamaCPP) -> None:
-        """Improve the prompt format of the given LLM.
-
-        Use the given metadata in the LLM to improve the prompt format.
-        """
-        # FIXME: below we are getting IDs (1,2,13) from llama.cpp, and not actual strings
-        llama_cpp_llm = cast(Llama, llm._model)
-        self.bos_token = llama_cpp_llm.token_bos()
-        self.eos_token = llama_cpp_llm.token_eos()
-        self.nl_token = llama_cpp_llm.token_nl()
-        print([self.bos_token, self.eos_token, self.nl_token])
-        # (1,2,13) are the IDs of the tokens
+    # def improve_prompt_format(self, llm: LlamaCPP) -> None:
+    #     """Improve the prompt format of the given LLM.
+    #
+    #     Use the given metadata in the LLM to improve the prompt format.
+    #     """
+    #     # FIXME: we are getting IDs (1,2,13) from llama.cpp, and not actual strings
+    #     llama_cpp_llm = cast(Llama, llm._model)
+    #     self.bos_token = llama_cpp_llm.token_bos()
+    #     self.eos_token = llama_cpp_llm.token_eos()
+    #     self.nl_token = llama_cpp_llm.token_nl()
+    #     print([self.bos_token, self.eos_token, self.nl_token])
+    #     # (1,2,13) are the IDs of the tokens
 
 
 class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
@@ -183,14 +183,18 @@ class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
         logger.debug("Got default_system_prompt='%s'", default_system_prompt)
         self.default_system_prompt = default_system_prompt
 
-    def _add_missing_system_prompt(self, messages: Sequence[ChatMessage]) -> list[ChatMessage]:
+    def _add_missing_system_prompt(
+        self, messages: Sequence[ChatMessage]
+    ) -> Sequence[ChatMessage]:
         if messages[0].role != MessageRole.SYSTEM:
             logger.debug(
                 "Adding system_promt='%s' to the given messages as there are none given in the session",
                 self.default_system_prompt,
             )
             messages = [
-                ChatMessage(content=self.default_system_prompt, role=MessageRole.SYSTEM),
+                ChatMessage(
+                    content=self.default_system_prompt, role=MessageRole.SYSTEM
+                ),
                 *messages,
             ]
         return messages
@@ -256,7 +260,11 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
     FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
     """
 
-    def __init__(self, default_system_prompt: str | None = None, add_generation_prompt: bool = True) -> None:
+    def __init__(
+        self,
+        default_system_prompt: str | None = None,
+        add_generation_prompt: bool = True,
+    ) -> None:
         # We have to define a default system prompt here as the LLM will not
         # use the default llama_utils functions.
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
@@ -272,7 +280,7 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
         messages = [ChatMessage(content=completion, role=MessageRole.USER)]
         return self._format_messages_to_prompt(messages)
 
-    def _format_messages_to_prompt(self, messages: list[ChatMessage]) -> str:
+    def _format_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         # TODO add BOS and EOS TOKEN !!!!! (c.f. jinja template)
         """Format message to prompt with `<|ROLE|>: MSG` style."""
         assert messages[0].role == MessageRole.SYSTEM
@@ -291,21 +299,23 @@ class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, prompt_style: str, default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self, prompt_style: str, default_system_prompt: str | None = None
+    ) -> None:
         """Wrapper for llama_cpp_python defined prompt format.
 
+        :param prompt_style:
         :param default_system_prompt: Used if no system prompt is given in the messages.
         """
+        assert prompt_style.startswith("llama_cpp.")
         default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
         super().__init__(default_system_prompt)
 
-        self.prompt_style = prompt_style
+        self.prompt_style = prompt_style[len("llama_cpp.") :]
         if self.prompt_style is None:
             return
         self._llama_cpp_formatter = _get_llama_cpp_chat_format(self.prompt_style)
-        self.messages_to_prompt = self._messages_to_prompt
-        self.completion_to_prompt = self._completion_to_prompt
 
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         messages = self._add_missing_system_prompt(messages)
@@ -314,19 +324,28 @@
         ).prompt
 
     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([ChatMessage(content=completion, role=MessageRole.USER)])
+        messages = self._add_missing_system_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
         return self._llama_cpp_formatter(
             messages=llama_index_to_llama_cpp_messages(messages)
         ).prompt
 
 
 class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
-    def __init__(self, template_name: str, add_generation_prompt: bool = True,
-                 default_system_prompt: str | None = None) -> None:
+    def __init__(
+        self,
+        template_name: str,
+        template_dir: str | None = None,
+        add_generation_prompt: bool = True,
+        default_system_prompt: str | None = None,
+    ) -> None:
         """Prompt format using a Jinja template.
 
         :param template_name: the filename of the template to use, must be in the
             `./template/` directory.
+        :param template_dir: the directory where the template is located.
+            Defaults to `./template/`.
         :param default_system_prompt: Used if no system prompt is given in the messages.
         """
@@ -335,12 +354,18 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
         self._add_generation_prompt = add_generation_prompt
 
-        def raise_exception(message):
+        def raise_exception(message: str) -> None:
             raise TemplateError(message)
 
-        self._jinja_fs_loader = FileSystemLoader(searchpath=THIS_DIRECTORY_RELATIVE / "template")
-        self._jinja_env = ImmutableSandboxedEnvironment(loader=self._jinja_fs_loader, trim_blocks=True,
-                                                        lstrip_blocks=True)
+        if template_dir is None:
+            self.template_dir = THIS_DIRECTORY_RELATIVE / "template"
+        else:
+            self.template_dir = Path(template_dir)
+
+        self._jinja_fs_loader = FileSystemLoader(searchpath=self.template_dir)
+        self._jinja_env = ImmutableSandboxedEnvironment(
+            loader=self._jinja_fs_loader, trim_blocks=True, lstrip_blocks=True
+        )
         self._jinja_env.globals["raise_exception"] = raise_exception
         self.template = self._jinja_env.get_template(template_name)
@@ -368,9 +393,11 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
         )
 
     def _completion_to_prompt(self, completion: str) -> str:
-        messages = self._add_missing_system_prompt([
-            ChatMessage(content=completion, role=MessageRole.USER),
-        ])
+        messages = self._add_missing_system_prompt(
+            [
+                ChatMessage(content=completion, role=MessageRole.USER),
+            ]
+        )
         return self._messages_to_prompt(messages)
@@ -380,17 +407,16 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
 # Pass all the arguments at once
 def get_prompt_style(
     prompt_style: str | None,
-        **kwargs: Any,
+    **kwargs: Any,
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
 
     :param prompt_style: The prompt style to use.
     :return: The prompt style to use.
     """
-    if prompt_style is None or prompt_style == "default":
+    if prompt_style is None:
         return DefaultPromptStyle(**kwargs)
     if prompt_style.startswith("llama_cpp."):
-        prompt_style = prompt_style[len("llama_cpp."):]
         return LlamaCppPromptStyle(prompt_style, **kwargs)
     elif prompt_style == "llama2":
         return LlamaIndexPromptStyle(**kwargs)
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index c5594691..e2270209 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -1,9 +1,14 @@
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+
 import pytest
 from llama_index.llms import ChatMessage, MessageRole
 
 from private_gpt.components.llm.prompt.prompt_helper import (
     DefaultPromptStyle,
+    LlamaCppPromptStyle,
     LlamaIndexPromptStyle,
+    TemplatePromptStyle,
     VigognePromptStyle,
     get_prompt_style,
 )
@@ -12,13 +17,44 @@ from private_gpt.components.llm.prompt.prompt_helper import (
 @pytest.mark.parametrize(
     ("prompt_style", "expected_prompt_style"),
     [
-        ("default", DefaultPromptStyle),
+        (None, DefaultPromptStyle),
         ("llama2", LlamaIndexPromptStyle),
-        ("tag", VigognePromptStyle),
+        ("vigogne", VigognePromptStyle),
+        ("llama_cpp.alpaca", LlamaCppPromptStyle),
+        ("llama_cpp.zephyr", LlamaCppPromptStyle),
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
-    assert get_prompt_style(prompt_style) == expected_prompt_style
+    assert type(get_prompt_style(prompt_style)) == expected_prompt_style
+
+
+def test_get_prompt_style_template_success():
+    jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
+    with NamedTemporaryFile("w") as tmp_file:
+        path = Path(tmp_file.name)
+        tmp_file.write(jinja_template)
+        tmp_file.flush()
+        tmp_file.seek(0)
+        prompt_style = get_prompt_style(
+            "template", template_name=path.name, template_dir=path.parent
+        )
+        assert type(prompt_style) == TemplatePromptStyle
+        prompt = prompt_style.messages_to_prompt(
+            [
+                ChatMessage(
+                    content="You are an AI assistant.", role=MessageRole.SYSTEM
+                ),
+                ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+            ]
+        )
+
+        expected_prompt = (
+            "<|system|>: You are an AI assistant.\n"
+            "<|user|>: Hello, how are you doing?\n"
+            "<|assistant|>: "
+        )
+
+        assert prompt == expected_prompt
 
 
 def test_get_prompt_style_failure():