From 9111d3a6369da71eb4c78d69bb20d20d00475d9a Mon Sep 17 00:00:00 2001 From: Nestor Qin Date: Tue, 23 Apr 2024 18:40:39 -0400 Subject: [PATCH] community[patch]: Fix message formatting for Anthropic models on Amazon Bedrock (#20801) **Description:** This PR fixes an issue in message formatting function for Anthropic models on Amazon Bedrock. Currently, LangChain BedrockChat model will crash if it uses Anthropic models and the model return a message in the following type: - `AIMessageChunk` Moreover, when use BedrockChat with for building Agent, the following message types will trigger the same issue too: - `HumanMessageChunk` - `FunctionMessage` **Issue:** https://github.com/langchain-ai/langchain/issues/18831 **Dependencies:** No. **Testing:** Manually tested. The following code was failing before the patch and works after. ``` @tool def square_root(x: str): "Useful when you need to calculate the square root of a number" return math.sqrt(int(x)) llm = ChatBedrock( model_id="anthropic.claude-3-sonnet-20240229-v1:0", model_kwargs={ "temperature": 0.0 }, ) prompt = ChatPromptTemplate.from_messages( [ ("system", FUNCTION_CALL_PROMPT), ("human", "Question: {user_input}"), MessagesPlaceholder(variable_name="agent_scratchpad"), ] ) tools = [square_root] tools_string = format_tool_to_anthropic_function(square_root) agent = ( RunnablePassthrough.assign( user_input=lambda x: x['user_input'], agent_scratchpad=lambda x: format_to_openai_function_messages( x["intermediate_steps"] ) ) | prompt | llm | AnthropicFunctionsAgentOutputParser() ) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True) output = agent_executor.invoke({ "user_input": "What is the square root of 2?", "tools_string": tools_string, }) ``` List of messages returned from Bedrock: ``` content='You are a helpful assistant.' content='Question: What is the square root of 2?' content="Okay, let's calculate the square root of 2.\nTo calculate the square root of a number, I can use the square_root tool:\n\n\n \n square_root\n \n <__arg1>2\n \n \n\n\n\n\n\nThe square root of 2 is approximately 1.414213562373095\n\n\n\n\nThe square root of 2 is approximately 1.414213562373095\n" id='run-92363df7-eff6-4849-bbba-fa16a1b2988c'" content='1.4142135623730951' name='square_root' ``` --- libs/community/langchain_community/chat_models/bedrock.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py index 74eb1af216e..852eb0091f6 100644 --- a/libs/community/langchain_community/chat_models/bedrock.py +++ b/libs/community/langchain_community/chat_models/bedrock.py @@ -193,7 +193,13 @@ class ChatPromptAdapter: ) -_message_type_lookups = {"human": "user", "ai": "assistant"} +_message_type_lookups = { + "human": "user", + "ai": "assistant", + "AIMessageChunk": "assistant", + "HumanMessageChunk": "user", + "function": "user", +} @deprecated(