Mirror of https://github.com/imartinez/privateGPT.git
added llama3 prompt (#1962)
* added llama3 prompt
* more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14
* fix: new llama3 prompt

Co-authored-by: Javier Martinez <javiermartinezalvarez98@gmail.com>
This commit is contained in: parent d4375d078f · commit d080969407
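The VectorStore -> BasePydanticVectorStore change mentioned in the message is not part of the hunks below; as a rough sketch of what that annotation swap entails (import path per llama-index-core around the linked 2024-05-14 release; the variable name here is illustrative):

    from llama_index.core.vector_stores.types import BasePydanticVectorStore

    # previously annotated as VectorStore
    vector_store: BasePydanticVectorStore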
private_gpt/components/llm/prompt_helper.py

@@ -138,6 +138,73 @@ class Llama2PromptStyle(AbstractPromptStyle):
         )
 
 
+class Llama3PromptStyle(AbstractPromptStyle):
+    r"""Template for Meta's Llama 3.1.
+
+    The format follows this structure:
+    <|begin_of_text|>
+    <|start_header_id|>system<|end_header_id|>
+
+    [System message content]<|eot_id|>
+    <|start_header_id|>user<|end_header_id|>
+
+    [User message content]<|eot_id|>
+    <|start_header_id|>assistant<|end_header_id|>
+
+    [Assistant message content]<|eot_id|>
+    ...
+    (Repeat for each message, including possible 'ipython' role)
+    """
+
+    BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
+    B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
+    EOT = "<|eot_id|>"
+    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
+    ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
+    DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. \
+Always answer as helpfully as possible and follow ALL given instructions. \
+Do not speculate or make up information. \
+Do not reference any given instructions or context. \
+"""
+
+    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
+        prompt = self.BOS
+        has_system_message = False
+
+        for i, message in enumerate(messages):
+            if not message or message.content is None:
+                continue
+            if message.role == MessageRole.SYSTEM:
+                prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
+                has_system_message = True
+            else:
+                role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
+                prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"
+
+            # Add assistant header if the last message is not from the assistant
+            if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
+                prompt += f"{self.ASSISTANT_INST}\n\n"
+
+        # Add default system prompt if no system message was provided
+        if not has_system_message:
+            prompt = (
+                f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+                + prompt[len(self.BOS) :]
+            )
+
+        # TODO: Implement tool handling logic
+
+        return prompt
+
+    def _completion_to_prompt(self, completion: str) -> str:
+        return (
+            f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+            f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
+            f"{self.ASSISTANT_INST}\n\n"
+        )
+
+
 class TagPromptStyle(AbstractPromptStyle):
     """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
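For orientation, a minimal sketch of the new style in use (messages_to_prompt is assumed to be the public wrapper AbstractPromptStyle provides around _messages_to_prompt, as the tests further down rely on):

    from llama_index.core.llms import ChatMessage, MessageRole

    from private_gpt.components.llm.prompt_helper import Llama3PromptStyle

    style = Llama3PromptStyle()
    # One user turn and no system message: the style splices in
    # DEFAULT_SYSTEM_PROMPT right after <|begin_of_text|> and closes with an
    # assistant header so the model generates the next turn.
    prompt = style.messages_to_prompt(
        [ChatMessage(content="Ping?", role=MessageRole.USER)]
    )
    assert prompt.startswith(
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
    )
    assert prompt.endswith("<|start_header_id|>assistant<|end_header_id|>\n\n")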
@@ -219,7 +286,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
+    | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
 
@@ -230,6 +298,8 @@ def get_prompt_style(
         return DefaultPromptStyle()
     elif prompt_style == "llama2":
         return Llama2PromptStyle()
+    elif prompt_style == "llama3":
+        return Llama3PromptStyle()
     elif prompt_style == "tag":
         return TagPromptStyle()
     elif prompt_style == "mistral":
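A quick check that the factory resolves the new literal (a sketch; imports assumed from the module above):

    from private_gpt.components.llm.prompt_helper import (
        Llama3PromptStyle,
        get_prompt_style,
    )

    assert isinstance(get_prompt_style("llama3"), Llama3PromptStyle)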
private_gpt/settings/settings.py

@@ -111,12 +111,15 @@ class LLMSettings(BaseModel):
         0.1,
         description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
     )
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
+    prompt_style: Literal[
+        "default", "llama2", "llama3", "tag", "mistral", "chatml"
+    ] = Field(
         "llama2",
         description=(
             "The prompt style to use for the chat engine. "
             "If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
             "If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
+            "If `llama3` - use the llama3 prompt style from the llama_index.\n"
             "If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`.\n"
             "If `mistral` - use the `mistral` prompt style. It should look like <s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]\n"
             "`llama2` is the historic behaviour. `default` might work better with your custom models."
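With the widened Literal, the new style can be selected from configuration; a sketch assuming the standard privateGPT settings.yaml layout, where LLMSettings binds to the llm section:

    llm:
      prompt_style: llama3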
pyproject.toml

@@ -124,7 +124,7 @@ target-version = ['py311']
 target-version = 'py311'
 
 # See all rules at https://beta.ruff.rs/docs/rules/
-select = [
+lint.select = [
     "E", # pycodestyle
     "W", # pycodestyle
     "F", # Pyflakes
@@ -141,7 +141,7 @@ select = [
     "RUF", # Ruff-specific rules
 ]
 
-ignore = [
+lint.ignore = [
     "E501", # "Line too long"
     # -> line length already regulated by black
     "PT011", # "pytest.raises() should specify expected exception"
@@ -159,24 +159,24 @@ ignore = [
     # -> "Missing docstring in public function too restrictive"
 ]
 
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 # Automatically disable rules that are incompatible with Google docstring convention
 convention = "google"
 
-[tool.ruff.pycodestyle]
+[tool.ruff.lint.pycodestyle]
 max-doc-length = 88
 
-[tool.ruff.flake8-tidy-imports]
+[tool.ruff.lint.flake8-tidy-imports]
 ban-relative-imports = "all"
 
-[tool.ruff.flake8-type-checking]
+[tool.ruff.lint.flake8-type-checking]
 strict = true
 runtime-evaluated-base-classes = ["pydantic.BaseModel"]
 # Pydantic needs to be able to evaluate types at runtime
 # see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
 # see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 # Allow missing docstrings for tests
 "tests/**/*.py" = ["D1"]
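For context: these pyproject.toml edits follow Ruff's relocation of lint configuration under a `lint` namespace (top-level select/ignore and the per-plugin tables were deprecated around Ruff 0.2). The dotted keys in the diff are equivalent to a dedicated table, abbreviated here:

    [tool.ruff.lint]
    select = ["E", "W", "F"]  # truncated; full list as in the diff
    ignore = ["E501", "PT011"]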
Tests for the prompt helper:

@@ -5,6 +5,7 @@ from private_gpt.components.llm.prompt_helper import (
     ChatMLPromptStyle,
     DefaultPromptStyle,
     Llama2PromptStyle,
+    Llama3PromptStyle,
     MistralPromptStyle,
     TagPromptStyle,
     get_prompt_style,
@@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt():
     )
 
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_format():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "Hello, how are you doing?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_with_default_system():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="Hello!", role=MessageRole.USER),
+    ]
+    expected = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert prompt_style._messages_to_prompt(messages) == expected
+
+
+def test_llama3_prompt_style_with_assistant_response():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="What is the capital of France?", role=MessageRole.USER),
+        ChatMessage(
+            content="The capital of France is Paris.", role=MessageRole.ASSISTANT
+        ),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "What is the capital of France?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        "The capital of France is Paris.<|eot_id|>"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
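To exercise just the new cases locally, a pytest keyword filter should do, e.g. `pytest -k llama3` against this test module (the exact file path is not shown in the diff).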