community[patch]: Bedrock add support for mistral models (#18756)

**Description**: My previous
[PR](https://github.com/langchain-ai/langchain/pull/18521) was
mistakenly closed, so I am reopening this one. Context: AWS released two
Mistral models on Bedrock last Friday (March 1, 2024). This PR includes
some code adjustments to ensure their compatibility with the Bedrock
class.
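
For reference, a minimal usage sketch (not part of this PR) of invoking one of the Bedrock-hosted Mistral models through the community chat wrapper. The model id and region below are assumptions; AWS credentials must already be configured and the model enabled in your account:

```python
# Hypothetical usage sketch: Bedrock-hosted Mistral via langchain_community.
# Assumes AWS credentials are configured and the model is enabled in-region.
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage, SystemMessage

chat = BedrockChat(
    model_id="mistral.mixtral-8x7b-instruct-v0:1",  # assumed Bedrock model id
    region_name="us-east-1",
    model_kwargs={"max_tokens": 256, "temperature": 0.5},
)

response = chat.invoke(
    [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="What is Amazon Bedrock?"),
    ]
)
print(response.content)
```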

---------

Co-authored-by: Anis ZAKARI <anis.zakari@hymaia.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Commit 37e89ba5b1 (parent 66576948e0), authored by Anis ZAKARI on 2024-03-09 02:20:38 +01:00, committed by GitHub.
2 changed files with 66 additions and 7 deletions

libs/community/langchain_community/chat_models/bedrock.py

@@ -5,7 +5,14 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import Extra
@@ -20,6 +27,27 @@ from langchain_community.utilities.anthropic import (
 )
+
+
+def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"[INST] {message.content} [/INST]"
+    elif isinstance(message, AIMessage):
+        message_text = f"{message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"<<SYS>> {message.content} <</SYS>>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for mistral."""
+    return "\n".join(
+        [_convert_one_message_to_text_mistral(message) for message in messages]
+    )
 
 
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}
@@ -137,6 +165,8 @@ class ChatPromptAdapter:
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
             prompt = convert_messages_to_prompt_llama(messages=messages)
+        elif provider == "mistral":
+            prompt = convert_messages_to_prompt_mistral(messages=messages)
         elif provider == "amazon":
             prompt = convert_messages_to_prompt_anthropic(
                 messages=messages,

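The conversion above maps each message type to Mistral's instruction format. A small sketch of the prompt it produces (assuming a build of `langchain_community` that includes this patch):

```python
from langchain_community.chat_models.bedrock import convert_messages_to_prompt_mistral
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Name one AWS region."),
    AIMessage(content="us-east-1"),
    HumanMessage(content="And another?"),
]

print(convert_messages_to_prompt_mistral(messages))
# <<SYS>> You are a helpful assistant. <</SYS>>
# [INST] Name one AWS region. [/INST]
# us-east-1
# [INST] And another? [/INST]
```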
libs/community/langchain_community/llms/bedrock.py

@@ -103,6 +103,7 @@ class LLMInputOutputAdapter:
         "amazon": "outputText",
         "cohere": "text",
         "meta": "generation",
+        "mistral": "outputs",
     }
 
     @classmethod
@@ -127,7 +128,7 @@ class LLMInputOutputAdapter:
             input_body["prompt"] = _human_assistant_format(prompt)
             if "max_tokens_to_sample" not in input_body:
                 input_body["max_tokens_to_sample"] = 1024
-        elif provider in ("ai21", "cohere", "meta"):
+        elif provider in ("ai21", "cohere", "meta", "mistral"):
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
@@ -156,6 +157,8 @@ class LLMInputOutputAdapter:
             text = response_body.get("generations")[0].get("text")
         elif provider == "meta":
             text = response_body.get("generation")
+        elif provider == "mistral":
+            text = response_body.get("outputs")[0].get("text")
         else:
             text = response_body.get("results")[0].get("outputText")
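For context, the mistral response body this extraction expects looks roughly as follows (shape inferred from the code above, not from the official Bedrock schema):

```python
# Shape implied by `response_body.get("outputs")[0].get("text")` above.
response_body = {
    "outputs": [
        {
            "text": "Amazon Bedrock is a managed foundation-model service.",
            "stop_reason": "stop",
        }
    ]
}
text = response_body.get("outputs")[0].get("text")
```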
@@ -198,6 +201,13 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return
+
+            elif (
+                provider == "mistral"
+                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
+            ):
+                return
+
             elif messages_api and (chunk_obj.get("type") == "content_block_stop"):
                 return
@@ -214,11 +224,17 @@ class LLMInputOutputAdapter:
             else:
                 # chunk obj format varies with provider
                 yield GenerationChunk(
-                    text=chunk_obj[output_key],
+                    text=(
+                        chunk_obj[output_key]
+                        if provider != "mistral"
+                        else chunk_obj[output_key][0]["text"]
+                    ),
                     generation_info={
-                        GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
-                        if GUARDRAILS_BODY_KEY in chunk_obj
-                        else None,
+                        GUARDRAILS_BODY_KEY: (
+                            chunk_obj.get(GUARDRAILS_BODY_KEY)
+                            if GUARDRAILS_BODY_KEY in chunk_obj
+                            else None
+                        ),
                     },
                 )
@@ -250,7 +266,19 @@ class LLMInputOutputAdapter:
             ):
                 return
 
-            yield GenerationChunk(text=chunk_obj[output_key])
+            if (
+                provider == "mistral"
+                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
+            ):
+                return
+
+            yield GenerationChunk(
+                text=(
+                    chunk_obj[output_key]
+                    if provider != "mistral"
+                    else chunk_obj[output_key][0]["text"]
+                )
+            )
 
 
 class BedrockBase(BaseModel, ABC):
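The streaming branches above read mistral chunks through the same "outputs" list: streaming ends once a chunk reports `stop_reason == "stop"`, and the text lives under `outputs[0]["text"]`. A toy sketch of that control flow, with chunk shapes inferred from the diff rather than the official Bedrock event schema:

```python
# Toy chunks mimicking the shape the mistral streaming branches expect;
# streaming stops once a chunk reports stop_reason == "stop".
chunks = [
    {"outputs": [{"text": "Hello", "stop_reason": None}]},
    {"outputs": [{"text": ", world", "stop_reason": None}]},
    {"outputs": [{"text": "", "stop_reason": "stop"}]},
]

for chunk_obj in chunks:
    if chunk_obj.get("outputs", [{}])[0].get("stop_reason", "") == "stop":
        break  # mirrors the early `return` in prepare_output_stream
    print(chunk_obj["outputs"][0]["text"], end="")
```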
@@ -300,6 +328,7 @@ class BedrockBase(BaseModel, ABC):
         "amazon": "stopSequences",
         "ai21": "stop_sequences",
         "cohere": "stop_sequences",
+        "mistral": "stop_sequences",
     }
 
     guardrails: Optional[Mapping[str, Any]] = {
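
With "mistral" added to the stop-sequence key map above, stop words passed at call time should be forwarded under "stop_sequences" in the mistral request body. A hedged example (model id and default region are assumptions):

```python
# Hypothetical call: `stop` tokens are mapped to the "stop_sequences" field
# of the mistral request body via the mapping above. Assumes AWS credentials
# and a default region are configured.
from langchain_community.llms import Bedrock

llm = Bedrock(model_id="mistral.mistral-7b-instruct-v0:2")  # assumed model id
print(llm.invoke("[INST] Count from 1 to 10. [/INST]", stop=["5"]))
```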