mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
openai: ignore function_calls if tool_calls are present (#31198)
Some providers include (legacy) function calls in `additional_kwargs` in addition to tool calls. We currently unpack both function calls and tool calls if present, but OpenAI will raise 400 in this case. This can come up if providers are mixed in a tool-calling loop. Example: ```python from langchain.chat_models import init_chat_model from langchain_core.messages import HumanMessage from langchain_core.tools import tool @tool def get_weather(location: str) -> str: """Get weather at a location.""" return "It's sunny." gemini = init_chat_model("google_genai:gemini-2.0-flash-001").bind_tools([get_weather]) openai = init_chat_model("openai:gpt-4.1-mini").bind_tools([get_weather]) input_message = HumanMessage("What's the weather in Boston?") tool_call_message = gemini.invoke([input_message]) assert len(tool_call_message.tool_calls) == 1 tool_call = tool_call_message.tool_calls[0] tool_message = get_weather.invoke(tool_call) response = openai.invoke( # currently raises 400 / BadRequestError [input_message, tool_call_message, tool_message] ) ``` Here we ignore function calls if tool calls are present.
This commit is contained in:
parent
83d006190d
commit
868cfc4a8f
@ -251,8 +251,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict["role"] = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict["role"] = "assistant"
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
|
||||
@ -267,6 +265,10 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
{k: v for k, v in tool_call.items() if k in tool_call_supported_props}
|
||||
for tool_call in message_dict["tool_calls"]
|
||||
]
|
||||
elif "function_call" in message.additional_kwargs:
|
||||
# OpenAI raises 400 if both function_call and tool_calls are present in the
|
||||
# same message.
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
else:
|
||||
pass
|
||||
# If tool calls present, content null value should be None not empty string.
|
||||
|
@ -614,6 +614,50 @@ def test_openai_invoke_name(mock_client: MagicMock) -> None:
|
||||
assert res.name == "Erick"
|
||||
|
||||
|
||||
def test_function_calls_with_tool_calls(mock_client: MagicMock) -> None:
|
||||
# Test that we ignore function calls if tool_calls are present
|
||||
llm = ChatOpenAI(model="gpt-4.1-mini")
|
||||
tool_call_message = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "Boston"}',
|
||||
}
|
||||
},
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"args": {"location": "Boston"},
|
||||
"id": "abc123",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
messages = [
|
||||
HumanMessage("What's the weather in Boston?"),
|
||||
tool_call_message,
|
||||
ToolMessage(content="It's sunny.", name="get_weather", tool_call_id="abc123"),
|
||||
]
|
||||
with patch.object(llm, "client", mock_client):
|
||||
_ = llm.invoke(messages)
|
||||
_, call_kwargs = mock_client.create.call_args
|
||||
call_messages = call_kwargs["messages"]
|
||||
tool_call_message_payload = call_messages[1]
|
||||
assert "tool_calls" in tool_call_message_payload
|
||||
assert "function_call" not in tool_call_message_payload
|
||||
|
||||
# Test we don't ignore function calls if tool_calls are not present
|
||||
cast(AIMessage, messages[1]).tool_calls = []
|
||||
with patch.object(llm, "client", mock_client):
|
||||
_ = llm.invoke(messages)
|
||||
_, call_kwargs = mock_client.create.call_args
|
||||
call_messages = call_kwargs["messages"]
|
||||
tool_call_message_payload = call_messages[1]
|
||||
assert "function_call" in tool_call_message_payload
|
||||
assert "tool_calls" not in tool_call_message_payload
|
||||
|
||||
|
||||
def test_custom_token_counting() -> None:
|
||||
def token_encoder(text: str) -> list[int]:
|
||||
return [1, 2, 3]
|
||||
|
Loading…
Reference in New Issue
Block a user