feat: add mistral + chatml prompts (#1426)

This commit is contained in:
CognitiveTech
2024-01-16 16:51:14 -05:00
committed by GitHub
parent 6191bcdbd6
commit e326126d0d
6 changed files with 107 additions and 5 deletions

View File

@@ -123,8 +123,51 @@ class TagPromptStyle(AbstractPromptStyle):
)
class MistralPromptStyle(AbstractPromptStyle):
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = "<s>"
for message in messages:
role = message.role
content = message.content or ""
if role.lower() == "system":
message_from_user = f"[INST] {content.strip()} [/INST]"
prompt += message_from_user
elif role.lower() == "user":
prompt += "</s>"
message_from_user = f"[INST] {content.strip()} [/INST]"
prompt += message_from_user
return prompt
def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
[ChatMessage(content=completion, role=MessageRole.USER)]
)
class ChatMLPromptStyle(AbstractPromptStyle):
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = "<|im_start|>system\n"
for message in messages:
role = message.role
content = message.content or ""
if role.lower() == "system":
message_from_user = f"{content.strip()}"
prompt += message_from_user
elif role.lower() == "user":
prompt += "<|im_end|>\n<|im_start|>user\n"
message_from_user = f"{content.strip()}<|im_end|>\n"
prompt += message_from_user
prompt += "<|im_start|>assistant\n"
return prompt
def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
[ChatMessage(content=completion, role=MessageRole.USER)]
)
def get_prompt_style(
prompt_style: Literal["default", "llama2", "tag"] | None
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.
@@ -137,4 +180,8 @@ def get_prompt_style(
return Llama2PromptStyle()
elif prompt_style == "tag":
return TagPromptStyle()
elif prompt_style == "mistral":
return MistralPromptStyle()
elif prompt_style == "chatml":
return ChatMLPromptStyle()
raise ValueError(f"Unknown prompt_style='{prompt_style}'")