From 071bc7ba59fbe67477de1b811f5413757ba9746d Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 12 Mar 2025 01:03:42 -0700 Subject: [PATCH] update --- .../langchain_openai/chat_models/base.py | 39 +++++++++++++-- .../tests/unit_tests/chat_models/test_base.py | 47 ++++++++++++++----- 2 files changed, 69 insertions(+), 17 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 55d65874b9f..5b28cc4b003 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -91,6 +91,7 @@ from langchain_core.runnables import ( ) from langchain_core.runnables.config import run_in_executor from langchain_core.tools import BaseTool +from langchain_core.tools.base import _stringify from langchain_core.utils import get_pydantic_field_names from langchain_core.utils.function_calling import ( convert_to_openai_function, @@ -2765,18 +2766,22 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list: for lc_msg in messages: msg = _convert_message_to_dict(lc_msg) if msg["role"] == "tool": + tool_output = msg["content"] + if not isinstance(tool_output, str): + tool_output = _stringify(tool_output) function_call_output = { "type": "function_call_output", - "output": msg["content"], + "output": tool_output, "call_id": msg["tool_call_id"], } input_.append(function_call_output) elif msg["role"] == "assistant": - if msg.get("content"): - input_.append(msg) + function_calls = [] if tool_calls := msg.pop("tool_calls", None): + # TODO: should you be able to preserve the function call object id on + # the langchain tool calls themselves? if not lc_msg.additional_kwargs.get(_FUNCTION_CALL_IDS_MAP_KEY): - raise ValueError(...) + raise ValueError("") function_call_ids = lc_msg.additional_kwargs[_FUNCTION_CALL_IDS_MAP_KEY] for tool_call in tool_calls: function_call = { @@ -2786,7 +2791,31 @@ def _construct_response_api_input(messages: Sequence[BaseMessage]) -> list: "call_id": tool_call["id"], "id": function_call_ids[tool_call["id"]], } - input_.append(function_call) + function_calls.append(function_call) + + msg["content"] = msg.get("content") or [] + if lc_msg.additional_kwargs.get("refusal"): + if isinstance(msg["content"], str): + msg["content"] = [ + { + "type": "output_text", + "text": msg["content"], + "annotations": [], + } + ] + msg["content"].append( + {"type": "refusal", "refusal": lc_msg.additional_kwargs["refusal"]} + ) + if isinstance(msg["content"], list): + for block in msg["content"]: + # chat api: {"type": "text", "text": "..."} + # response api: {"type": "output_text", "text": "...", "annotations": [...]} # noqa: E501 + if block["type"] == "text": + block["type"] = "output_text" + block["annotations"] = block.get("annotations") or [] + if msg["content"]: + input_.append(msg) + input_.extend(function_calls) elif msg["role"] == "user": if isinstance(msg["content"], list): for block in msg["content"]: diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 938c0bc8115..ebfa2b83fbc 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -1606,6 +1606,9 @@ def test__construct_response_api_input_multiple_message_types() -> None: messages = [ SystemMessage(content="You are a helpful assistant."), HumanMessage(content="What's the weather in San Francisco?"), + HumanMessage( + content=[{"type": "text", "text": "What's the weather in San Francisco?"}] + ), AIMessage( content="", tool_calls=[ @@ -1623,11 +1626,19 @@ def test__construct_response_api_input_multiple_message_types() -> None: tool_call_id="call_123", ), AIMessage(content="The weather in San Francisco is 72°F and sunny."), + AIMessage( + content=[ + { + "type": "text", + "text": "The weather in San Francisco is 72°F and sunny.", + } + ] + ), ] result = _construct_response_api_input(messages) - assert len(result) == 5 + assert len(result) == len(messages) # Check system message assert result[0]["role"] == "system" @@ -1636,19 +1647,31 @@ def test__construct_response_api_input_multiple_message_types() -> None: # Check human message assert result[1]["role"] == "user" assert result[1]["content"] == "What's the weather in San Francisco?" + assert result[2]["role"] == "user" + assert result[2]["content"] == [ + {"type": "input_text", "text": "What's the weather in San Francisco?"} + ] # Check function call - assert result[2]["type"] == "function_call" - assert result[2]["name"] == "get_weather" - assert result[2]["arguments"] == '{"location": "San Francisco"}' - assert result[2]["call_id"] == "call_123" - assert result[2]["id"] == "func_456" + assert result[3]["type"] == "function_call" + assert result[3]["name"] == "get_weather" + assert result[3]["arguments"] == '{"location": "San Francisco"}' + assert result[3]["call_id"] == "call_123" + assert result[3]["id"] == "func_456" # Check function call output - assert result[3]["type"] == "function_call_output" - assert result[3]["output"] == '{"temperature": 72, "conditions": "sunny"}' - assert result[3]["call_id"] == "call_123" + assert result[4]["type"] == "function_call_output" + assert result[4]["output"] == '{"temperature": 72, "conditions": "sunny"}' + assert result[4]["call_id"] == "call_123" - # Check final AI message - assert result[4]["role"] == "assistant" - assert result[4]["content"] == "The weather in San Francisco is 72°F and sunny." + assert result[5]["role"] == "assistant" + assert result[5]["content"] == "The weather in San Francisco is 72°F and sunny." + + assert result[6]["role"] == "assistant" + assert result[6]["content"] == [ + { + "type": "output_text", + "text": "The weather in San Francisco is 72°F and sunny.", + "annotations": [], + } + ]