mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 13:06:03 +00:00
community[patch]: Bedrock add support for mistral models (#18756)
*Description**: My previous [PR](https://github.com/langchain-ai/langchain/pull/18521) was mistakenly closed, so I am reopening this one. Context: AWS released two Mistral models on Bedrock last Friday (March 1, 2024). This PR includes some code adjustments to ensure their compatibility with the Bedrock class. --------- Co-authored-by: Anis ZAKARI <anis.zakari@hymaia.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -5,7 +5,14 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Extra
|
||||
|
||||
@@ -20,6 +27,27 @@ from langchain_community.utilities.anthropic import (
|
||||
)
|
||||
|
||||
|
||||
def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_text = f"[INST] {message.content} [/INST]"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_text = f"{message.content}"
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_text = f"<<SYS>> {message.content} <</SYS>>"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_text
|
||||
|
||||
|
||||
def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
|
||||
"""Convert a list of messages to a prompt for mistral."""
|
||||
return "\n".join(
|
||||
[_convert_one_message_to_text_mistral(message) for message in messages]
|
||||
)
|
||||
|
||||
|
||||
def _format_image(image_url: str) -> Dict:
|
||||
"""
|
||||
Formats an image of format data:image/jpeg;base64,{b64_string}
|
||||
@@ -137,6 +165,8 @@ class ChatPromptAdapter:
|
||||
prompt = convert_messages_to_prompt_anthropic(messages=messages)
|
||||
elif provider == "meta":
|
||||
prompt = convert_messages_to_prompt_llama(messages=messages)
|
||||
elif provider == "mistral":
|
||||
prompt = convert_messages_to_prompt_mistral(messages=messages)
|
||||
elif provider == "amazon":
|
||||
prompt = convert_messages_to_prompt_anthropic(
|
||||
messages=messages,
|
||||
|
Reference in New Issue
Block a user