fix: mistral ignoring assistant messages

This commit is contained in:
Pablo Orgaz 2024-05-28 19:30:44 +02:00 committed by GitHub
parent 3b3e96ad6c
commit 136f8b5208
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -173,18 +173,20 @@ class TagPromptStyle(AbstractPromptStyle):
class MistralPromptStyle(AbstractPromptStyle): class MistralPromptStyle(AbstractPromptStyle):
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str: def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = "<s>" inst_buffer = []
text = ""
for message in messages: for message in messages:
role = message.role if message.role == MessageRole.SYSTEM:
content = message.content or "" inst_buffer.append(message.content.strip())
if role.lower() == "system": elif message.role == MessageRole.USER:
message_from_user = f"[INST] {content.strip()} [/INST]" inst_buffer.append(message.content.strip())
prompt += message_from_user text += "<s>[/INST] " + "\n".join(inst_buffer) + " [/INST]"
elif role.lower() == "user": inst_buffer.clear()
prompt += "</s>" elif message.role == MessageRole.ASSISTANT:
message_from_user = f"[INST] {content.strip()} [/INST]" text += " " + message.content.strip() + "</s>"
prompt += message_from_user else:
return prompt raise ValueError(f"Unknown message role {message.role}")
return text
def _completion_to_prompt(self, completion: str) -> str: def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt( return self._messages_to_prompt(