mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-09-03 08:15:14 +00:00
fix(LLM): mistral ignoring assistant messages (#1954)
* fix: mistral ignoring assistant messages * fix: typing * fix: fix tests
This commit is contained in:
@@ -173,18 +173,22 @@ class TagPromptStyle(AbstractPromptStyle):
|
||||
|
||||
class MistralPromptStyle(AbstractPromptStyle):
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
prompt = "<s>"
|
||||
inst_buffer = []
|
||||
text = ""
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content = message.content or ""
|
||||
if role.lower() == "system":
|
||||
message_from_user = f"[INST] {content.strip()} [/INST]"
|
||||
prompt += message_from_user
|
||||
elif role.lower() == "user":
|
||||
prompt += "</s>"
|
||||
message_from_user = f"[INST] {content.strip()} [/INST]"
|
||||
prompt += message_from_user
|
||||
return prompt
|
||||
if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER:
|
||||
inst_buffer.append(str(message.content).strip())
|
||||
elif message.role == MessageRole.ASSISTANT:
|
||||
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
|
||||
text += " " + str(message.content).strip() + "</s>"
|
||||
inst_buffer.clear()
|
||||
else:
|
||||
raise ValueError(f"Unknown message role {message.role}")
|
||||
|
||||
if len(inst_buffer) > 0:
|
||||
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
|
||||
|
||||
return text
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
return self._messages_to_prompt(
|
||||
|
Reference in New Issue
Block a user