diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index cb1b68903b1..5ef15b8a608 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -251,8 +251,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict["role"] = "user" elif isinstance(message, AIMessage): message_dict["role"] = "assistant" - if "function_call" in message.additional_kwargs: - message_dict["function_call"] = message.additional_kwargs["function_call"] if message.tool_calls or message.invalid_tool_calls: message_dict["tool_calls"] = [ _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls @@ -267,6 +265,10 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: {k: v for k, v in tool_call.items() if k in tool_call_supported_props} for tool_call in message_dict["tool_calls"] ] + elif "function_call" in message.additional_kwargs: + # OpenAI raises 400 if both function_call and tool_calls are present in the + # same message. + message_dict["function_call"] = message.additional_kwargs["function_call"] else: pass # If tool calls present, content null value should be None not empty string. diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index c762b3dab5d..2331d465ef8 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -614,6 +614,50 @@ def test_openai_invoke_name(mock_client: MagicMock) -> None: assert res.name == "Erick" +def test_function_calls_with_tool_calls(mock_client: MagicMock) -> None: + # Test that we ignore function calls if tool_calls are present + llm = ChatOpenAI(model="gpt-4.1-mini") + tool_call_message = AIMessage( + content="", + additional_kwargs={ + "function_call": { + "name": "get_weather", + "arguments": '{"location": "Boston"}', + } + }, + tool_calls=[ + { + "name": "get_weather", + "args": {"location": "Boston"}, + "id": "abc123", + "type": "tool_call", + } + ], + ) + messages = [ + HumanMessage("What's the weather in Boston?"), + tool_call_message, + ToolMessage(content="It's sunny.", name="get_weather", tool_call_id="abc123"), + ] + with patch.object(llm, "client", mock_client): + _ = llm.invoke(messages) + _, call_kwargs = mock_client.create.call_args + call_messages = call_kwargs["messages"] + tool_call_message_payload = call_messages[1] + assert "tool_calls" in tool_call_message_payload + assert "function_call" not in tool_call_message_payload + + # Test we don't ignore function calls if tool_calls are not present + cast(AIMessage, messages[1]).tool_calls = [] + with patch.object(llm, "client", mock_client): + _ = llm.invoke(messages) + _, call_kwargs = mock_client.create.call_args + call_messages = call_kwargs["messages"] + tool_call_message_payload = call_messages[1] + assert "function_call" in tool_call_message_payload + assert "tool_calls" not in tool_call_message_payload + + def test_custom_token_counting() -> None: def token_encoder(text: str) -> list[int]: return [1, 2, 3]