community[patch]: Bedrock add support for mistral models (#18756)
**Description**: My previous [PR](https://github.com/langchain-ai/langchain/pull/18521) was mistakenly closed, so I am reopening this one.

Context: AWS released two Mistral models on Bedrock last Friday (March 1, 2024). This PR includes some code adjustments to ensure their compatibility with the Bedrock class.

---------

Co-authored-by: Anis ZAKARI <anis.zakari@hymaia.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
66576948e0
commit
37e89ba5b1
@@ -5,7 +5,14 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import Extra
 
@@ -20,6 +27,27 @@ from langchain_community.utilities.anthropic import (
 )
 
 
+def _convert_one_message_to_text_mistral(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"[INST] {message.content} [/INST]"
+    elif isinstance(message, AIMessage):
+        message_text = f"{message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"<<SYS>> {message.content} <</SYS>>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
+    """Convert a list of messages to a prompt for mistral."""
+    return "\n".join(
+        [_convert_one_message_to_text_mistral(message) for message in messages]
+    )
+
+
 def _format_image(image_url: str) -> Dict:
     """
     Formats an image of format data:image/jpeg;base64,{b64_string}
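For reference, a quick sketch of the prompt string these helpers produce for a typical chat exchange. The message contents are made up for illustration, and the import path assumes the helper is exposed from `langchain_community.chat_models.bedrock` as in this diff:

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

# Hypothetical conversation, just to illustrate the formatting above.
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is the capital of France?"),
    AIMessage(content="The capital of France is Paris."),
]

# convert_messages_to_prompt_mistral(messages) joins the per-message strings
# with "\n", yielding roughly:
#
# <<SYS>> You are a helpful assistant. <</SYS>>
# [INST] What is the capital of France? [/INST]
# The capital of France is Paris.
```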
@@ -137,6 +165,8 @@ class ChatPromptAdapter:
             prompt = convert_messages_to_prompt_anthropic(messages=messages)
         elif provider == "meta":
             prompt = convert_messages_to_prompt_llama(messages=messages)
+        elif provider == "mistral":
+            prompt = convert_messages_to_prompt_mistral(messages=messages)
+        elif provider == "amazon":
             prompt = convert_messages_to_prompt_anthropic(
                 messages=messages,
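A minimal usage sketch of the chat-model path this hunk enables. The `model_id` and `model_kwargs` below are illustrative (check the Bedrock console for the exact Mistral identifiers available in your region); the provider is typically inferred from the `"mistral."` prefix of the model id, which routes prompt building to `convert_messages_to_prompt_mistral`:

```python
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

chat = BedrockChat(
    model_id="mistral.mistral-7b-instruct-v0:2",  # illustrative Mistral model id
    model_kwargs={"temperature": 0.1},
)

response = chat.invoke([HumanMessage(content="Tell me a joke about otters.")])
print(response.content)
```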
@@ -103,6 +103,7 @@ class LLMInputOutputAdapter:
         "amazon": "outputText",
         "cohere": "text",
         "meta": "generation",
+        "mistral": "outputs",
     }
 
     @classmethod
@@ -127,7 +128,7 @@ class LLMInputOutputAdapter:
             input_body["prompt"] = _human_assistant_format(prompt)
             if "max_tokens_to_sample" not in input_body:
                 input_body["max_tokens_to_sample"] = 1024
-        elif provider in ("ai21", "cohere", "meta"):
+        elif provider in ("ai21", "cohere", "meta", "mistral"):
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
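Under this branch, Mistral models are handled like the other text-completion providers: the converted prompt is placed under `"prompt"` alongside whatever `model_kwargs` the caller supplied. A rough sketch of the resulting request body; the parameter names below follow Bedrock's Mistral schema as I understand it and are assumptions, not taken from this diff:

```python
import json

# Hypothetical model_kwargs plus the converted prompt.
input_body = {
    "temperature": 0.1,
    "max_tokens": 512,
    "prompt": "[INST] What is the capital of France? [/INST]",
}

body = json.dumps(input_body)  # serialized and sent as the InvokeModel request body
```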
@@ -156,6 +157,8 @@ class LLMInputOutputAdapter:
             text = response_body.get("generations")[0].get("text")
         elif provider == "meta":
             text = response_body.get("generation")
+        elif provider == "mistral":
+            text = response_body.get("outputs")[0].get("text")
         else:
             text = response_body.get("results")[0].get("outputText")
 
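An illustrative (non-streaming) Mistral response body and the extraction the new branch performs. The payload values are made up; only the `outputs[0].text` shape matters here:

```python
# Hypothetical Bedrock Mistral response body after json.loads().
response_body = {
    "outputs": [
        {"text": "The capital of France is Paris.", "stop_reason": "stop"}
    ]
}

# The new "mistral" branch reads the first output's text field.
text = response_body.get("outputs")[0].get("text")
assert text == "The capital of France is Paris."
```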
@@ -198,6 +201,13 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return
 
+            elif (
+                provider == "mistral"
+                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
+            ):
+                return
+
             elif messages_api and (chunk_obj.get("type") == "content_block_stop"):
                 return
 
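A small sketch of how that stop check reads a streamed Mistral chunk. For the `"mistral"` provider the output key is `"outputs"` (see the output-key mapping added above), so the check looks at `outputs[0].stop_reason`; the chunk payload below is illustrative:

```python
output_key = "outputs"

# Hypothetical final chunk emitted by the Bedrock response stream for Mistral.
chunk_obj = {"outputs": [{"text": "", "stop_reason": "stop"}]}

is_done = (
    chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
)  # True -> the generator returns instead of yielding an empty trailing chunk
```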
@@ -214,11 +224,17 @@ class LLMInputOutputAdapter:
             else:
                 # chunk obj format varies with provider
                 yield GenerationChunk(
-                    text=chunk_obj[output_key],
+                    text=(
+                        chunk_obj[output_key]
+                        if provider != "mistral"
+                        else chunk_obj[output_key][0]["text"]
+                    ),
                     generation_info={
-                        GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
-                        if GUARDRAILS_BODY_KEY in chunk_obj
-                        else None,
+                        GUARDRAILS_BODY_KEY: (
+                            chunk_obj.get(GUARDRAILS_BODY_KEY)
+                            if GUARDRAILS_BODY_KEY in chunk_obj
+                            else None
+                        ),
                     },
                 )
 
@@ -250,7 +266,19 @@ class LLMInputOutputAdapter:
             ):
                 return
 
-            yield GenerationChunk(text=chunk_obj[output_key])
+            if (
+                provider == "mistral"
+                and chunk_obj.get(output_key, [{}])[0].get("stop_reason", "") == "stop"
+            ):
+                return
+
+            yield GenerationChunk(
+                text=(
+                    chunk_obj[output_key]
+                    if provider != "mistral"
+                    else chunk_obj[output_key][0]["text"]
+                )
+            )
 
 
 class BedrockBase(BaseModel, ABC):
@@ -300,6 +328,7 @@ class BedrockBase(BaseModel, ABC):
         "amazon": "stopSequences",
         "ai21": "stop_sequences",
         "cohere": "stop_sequences",
+        "mistral": "stop_sequences",
     }
 
     guardrails: Optional[Mapping[str, Any]] = {
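Finally, a minimal end-to-end sketch of the completion-LLM path, streaming included since several of the hunks above touch it. The `model_id` and `model_kwargs` are illustrative; a `stop` list passed at call time would be injected under the provider-specific key from the map above (`"stop_sequences"` for Mistral per the new entry):

```python
from langchain_community.llms import Bedrock

llm = Bedrock(
    model_id="mistral.mixtral-8x7b-instruct-v0:1",  # illustrative Mistral model id
    model_kwargs={"max_tokens": 256, "temperature": 0.5},
    streaming=True,
)

prompt = "[INST] Explain what Amazon Bedrock is in one sentence. [/INST]"
for chunk in llm.stream(prompt):
    print(chunk, end="", flush=True)
```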