From 94712824d6e079f49f936a280f5dbcc9d46b964c Mon Sep 17 00:00:00 2001
From: Robert Hirsch
Date: Thu, 6 Jun 2024 20:29:23 +0200
Subject: [PATCH] added llama3 prompt

---
 private_gpt/components/llm/prompt_helper.py | 74 ++++++++++++++++++++-
 private_gpt/settings/settings.py            |  3 +-
 2 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index 77158200..858e1fce 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -138,6 +138,76 @@ class Llama2PromptStyle(AbstractPromptStyle):
     )
 
 
+class Llama3PromptStyle(AbstractPromptStyle):
+    """Llama 3 prompt style, following Meta's chat template:
+
+    {% set loop_messages = messages %}
+    {% for message in loop_messages %}
+    {% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
+    {% if loop.index0 == 0 %}
+    {% set content = bos_token + content %}
+    {% endif %}
+    {{ content }}
+    {% endfor %}
+    {% if add_generation_prompt %}
+    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+    {% endif %}
+    """
+
+    BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
+    B_INST, E_INST = "<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"
+    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>\n\n", "<|eot_id|>"
+    ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    DEFAULT_SYSTEM_PROMPT = """\
+    You are a helpful, respectful and honest assistant. \
+    Always answer as helpfully as possible and follow ALL given instructions. \
+    Do not speculate or make up information. \
+    Do not reference any given instructions or context. \
+    """
+
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        string_messages: list[str] = []
+        if messages[0].role == MessageRole.SYSTEM:
+            system_message_str = messages[0].content or ""
+            messages = messages[1:]
+        else:
+            system_message_str = self.DEFAULT_SYSTEM_PROMPT
+
+        # Per the template: a single BOS token at the very start, then
+        # exactly one system block terminated by <|eot_id|>.
+        string_messages.append(
+            f"{self.BOS}{self.B_SYS}{system_message_str.strip()}{self.E_SYS}"
+        )
+
+        for i in range(0, len(messages), 2):
+            user_message = messages[i]
+            assert user_message.role == MessageRole.USER
+
+            # User block, then the assistant header: a recorded reply fills
+            # it below, otherwise the model generates from there.
+            str_message = (
+                f"{self.B_INST}{str(user_message.content).strip()}{self.E_INST}"
+                f"{self.ASSISTANT_INST}"
+            )
+
+            if len(messages) > (i + 1):
+                assistant_message = messages[i + 1]
+                assert assistant_message.role == MessageRole.ASSISTANT
+                str_message += (
+                    f"{str(assistant_message.content).strip()}{self.E_SYS}"
+                )
+
+            string_messages.append(str_message)
+
+        return "".join(string_messages)
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return (
+            f"{self.BOS}{self.B_SYS}{self.DEFAULT_SYSTEM_PROMPT.strip()}{self.E_SYS}"
+            f"{self.B_INST}{completion.strip()}{self.E_INST}"
+            f"{self.ASSISTANT_INST}"
+        )
+
+
 class TagPromptStyle(AbstractPromptStyle):
     """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
 
@@ -219,7 +289,7 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
 
@@ -230,6 +300,8 @@ def get_prompt_style(
         return DefaultPromptStyle()
     elif prompt_style == "llama2":
         return Llama2PromptStyle()
+    elif prompt_style == "llama3":
+        return Llama3PromptStyle()
     elif prompt_style == "tag":
         return TagPromptStyle()
     elif prompt_style == "mistral":
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 28ece459..e2598b88 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -104,12 +104,13 @@ class LLMSettings(BaseModel):
         0.1,
         description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
     )
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
+    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] = Field(
         "llama2",
         description=(
             "The prompt style to use for the chat engine. "
             "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
             "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
+            "If `llama3` - use the llama3 prompt style from the llama_index. Based on `<|start_header_id|>`, `<|end_header_id|>` and `<|eot_id|>`.\n"
             "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
             "If `mistral` - use the `mistral` prompt style. It should look like `<s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]`\n"
             "`llama2` is the historic behaviour. `default` might work better with your custom models."
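
Note (review aid, not part of the patch): a minimal sketch of how the new style
renders a chat, assuming the `Llama3PromptStyle` above and llama_index's
`ChatMessage`/`MessageRole`; `messages_to_prompt` is the public wrapper that
`AbstractPromptStyle` exposes around `_messages_to_prompt`. The message
contents are made up for illustration:

    from llama_index.core.llms import ChatMessage, MessageRole

    from private_gpt.components.llm.prompt_helper import get_prompt_style

    # One system message plus one user turn; the style appends the assistant
    # header at the end so the model knows to generate the reply.
    style = get_prompt_style("llama3")
    prompt = style.messages_to_prompt(
        [
            ChatMessage(role=MessageRole.SYSTEM, content="You are a terse assistant."),
            ChatMessage(role=MessageRole.USER, content="What is PrivateGPT?"),
        ]
    )
    print(prompt)
    # Expected shape (blank lines come from the `\n\n` in the header constants):
    # <|begin_of_text|><|start_header_id|>system<|end_header_id|>
    #
    # You are a terse assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
    #
    # What is PrivateGPT?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #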