From 37e89ba5b11e609f09e12aaeb0fbd1ddca628b2e Mon Sep 17 00:00:00 2001
From: Anis ZAKARI <73587508+AnisZakari@users.noreply.github.com>
Date: Sat, 9 Mar 2024 02:20:38 +0100
Subject: [PATCH] community[patch]: Bedrock add support for mistral models
 (#18756)

**Description**: My previous
[PR](https://github.com/langchain-ai/langchain/pull/18521) was mistakenly
closed, so I am reopening this one.

Context: AWS released two Mistral models on Bedrock last Friday
(March 1, 2024). This PR includes some code adjustments to ensure their
compatibility with the Bedrock class.

---------

Co-authored-by: Anis ZAKARI
Co-authored-by: Erick Friis
---
 .../chat_models/bedrock.py                    | 32 ++++++++++++++-
 .../langchain_community/llms/bedrock.py       | 41 ++++++++++++++++---
 2 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
index f9d87b6274a..933343d6a96 100644
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -5,7 +5,14 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import Extra

@@ -20,6 +27,27 @@ from langchain_community.utilities.anthropic import (
 )


+def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"[INST] {message.content} [/INST]"
+    elif isinstance(message, AIMessage):
+        message_text = f"{message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"<<SYS>> {message.content} <</SYS>>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for mistral."""
+    return "\n".join(
+        [_convert_one_message_to_text_mistral(message) for message in messages]
+    )
+
+
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}
@@ -137,6 +165,8 @@ class ChatPromptAdapter:
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
             prompt = convert_messages_to_prompt_llama(messages=messages)
+        elif provider == "mistral":
+            prompt = convert_messages_to_prompt_mistral(messages=messages)
         elif provider == "amazon":
             prompt = convert_messages_to_prompt_anthropic(
                 messages=messages,
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index d126995b93c..9b7515a5f4d 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -103,6 +103,7 @@ class LLMInputOutputAdapter:
         "amazon": "outputText",
         "cohere": "text",
         "meta": "generation",
+        "mistral": "outputs",
     }

     @classmethod
@@ -127,7 +128,7 @@ class LLMInputOutputAdapter:
                 input_body["prompt"] = _human_assistant_format(prompt)
                 if "max_tokens_to_sample" not in input_body:
                     input_body["max_tokens_to_sample"] = 1024
-        elif provider in ("ai21", "cohere", "meta"):
+        elif provider in ("ai21", "cohere", "meta", "mistral"):
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
@@ -156,6 +157,8 @@ class LLMInputOutputAdapter:
                 text = response_body.get("generations")[0].get("text")
             elif provider == "meta":
                 text = response_body.get("generation")
+            elif provider == "mistral":
+                text = response_body.get("outputs")[0].get("text")
             else:
                 text = response_body.get("results")[0].get("outputText")

@@ -198,6 +201,13 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == ""
             ):
                 return
+
+            elif (
+                provider == "mistral"
+                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
+            ):
+                return
+
             elif messages_api and (chunk_obj.get("type") == "content_block_stop"):
                 return

@@ -214,11 +224,17 @@ class LLMInputOutputAdapter:
             else:
                 # chunk obj format varies with provider
                 yield GenerationChunk(
-                    text=chunk_obj[output_key],
+                    text=(
+                        chunk_obj[output_key]
+                        if provider != "mistral"
+                        else chunk_obj[output_key][0]["text"]
+                    ),
                     generation_info={
-                        GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
-                        if GUARDRAILS_BODY_KEY in chunk_obj
-                        else None,
+                        GUARDRAILS_BODY_KEY: (
+                            chunk_obj.get(GUARDRAILS_BODY_KEY)
+                            if GUARDRAILS_BODY_KEY in chunk_obj
+                            else None
+                        ),
                     },
                 )

@@ -250,7 +266,19 @@ class LLMInputOutputAdapter:
             ):
                 return

-            yield GenerationChunk(text=chunk_obj[output_key])
+            if (
+                provider == "mistral"
+                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
+            ):
+                return
+
+            yield GenerationChunk(
+                text=(
+                    chunk_obj[output_key]
+                    if provider != "mistral"
+                    else chunk_obj[output_key][0]["text"]
+                )
+            )


 class BedrockBase(BaseModel, ABC):
@@ -300,6 +328,7 @@ class BedrockBase(BaseModel, ABC):
         "amazon": "stopSequences",
         "ai21": "stop_sequences",
         "cohere": "stop_sequences",
+        "mistral": "stop_sequences",
     }

     guardrails: Optional[Mapping[str, Any]] = {