mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-09-03 00:07:23 +00:00
feat: add mistral + chatml prompts (#1426)
This commit is contained in:
@@ -123,8 +123,51 @@ class TagPromptStyle(AbstractPromptStyle):
|
||||
)
|
||||
|
||||
|
||||
class MistralPromptStyle(AbstractPromptStyle):
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
prompt = "<s>"
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content = message.content or ""
|
||||
if role.lower() == "system":
|
||||
message_from_user = f"[INST] {content.strip()} [/INST]"
|
||||
prompt += message_from_user
|
||||
elif role.lower() == "user":
|
||||
prompt += "</s>"
|
||||
message_from_user = f"[INST] {content.strip()} [/INST]"
|
||||
prompt += message_from_user
|
||||
return prompt
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
return self._messages_to_prompt(
|
||||
[ChatMessage(content=completion, role=MessageRole.USER)]
|
||||
)
|
||||
|
||||
|
||||
class ChatMLPromptStyle(AbstractPromptStyle):
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
prompt = "<|im_start|>system\n"
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content = message.content or ""
|
||||
if role.lower() == "system":
|
||||
message_from_user = f"{content.strip()}"
|
||||
prompt += message_from_user
|
||||
elif role.lower() == "user":
|
||||
prompt += "<|im_end|>\n<|im_start|>user\n"
|
||||
message_from_user = f"{content.strip()}<|im_end|>\n"
|
||||
prompt += message_from_user
|
||||
prompt += "<|im_start|>assistant\n"
|
||||
return prompt
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
return self._messages_to_prompt(
|
||||
[ChatMessage(content=completion, role=MessageRole.USER)]
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_style(
|
||||
prompt_style: Literal["default", "llama2", "tag"] | None
|
||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
|
||||
) -> AbstractPromptStyle:
|
||||
"""Get the prompt style to use from the given string.
|
||||
|
||||
@@ -137,4 +180,8 @@ def get_prompt_style(
|
||||
return Llama2PromptStyle()
|
||||
elif prompt_style == "tag":
|
||||
return TagPromptStyle()
|
||||
elif prompt_style == "mistral":
|
||||
return MistralPromptStyle()
|
||||
elif prompt_style == "chatml":
|
||||
return ChatMLPromptStyle()
|
||||
raise ValueError(f"Unknown prompt_style='{prompt_style}'")
|
||||
|
Reference in New Issue
Block a user