feat(openai): add tool search support with defer_loading via extras

This commit is contained in:
Open SWE
2026-03-05 20:04:24 +00:00
parent 7a4cc3ec32
commit fc9b39db3f
6 changed files with 150 additions and 7 deletions

View File

@@ -508,6 +508,8 @@ _WellKnownOpenAITools = (
"image_generation",
"web_search_preview",
"web_search",
"tool_search",
"namespace",
)

2
libs/core/uv.lock generated
View File

@@ -992,7 +992,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.2.16"
version = "1.2.17"
source = { editable = "." }
dependencies = [
{ name = "jsonpatch" },

View File

@@ -166,8 +166,15 @@ WellKnownTools = (
"code_interpreter",
"mcp",
"image_generation",
"tool_search",
"namespace",
)
# Allow-list of BaseTool.extras keys copied verbatim onto the formatted OpenAI
# tool dict; currently only "defer_loading" (tool-search deferred loading).
_OPENAI_EXTRA_FIELDS: set[str] = {
    "defer_loading",
}
"""Valid OpenAI-specific extra fields that are promoted from BaseTool.extras."""
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
@@ -1981,9 +1988,18 @@ class BaseChatOpenAI(BaseChatModel):
""" # noqa: E501
if parallel_tool_calls is not None:
kwargs["parallel_tool_calls"] = parallel_tool_calls
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
formatted_tools = []
for tool in tools:
formatted = convert_to_openai_tool(tool, strict=strict)
if (
isinstance(tool, BaseTool)
and hasattr(tool, "extras")
and isinstance(tool.extras, dict)
):
for key, value in tool.extras.items():
if key in _OPENAI_EXTRA_FIELDS:
formatted[key] = value
formatted_tools.append(formatted)
tool_names = []
for tool in formatted_tools:
if "function" in tool:
@@ -3981,7 +3997,11 @@ def _construct_responses_api_payload(
# chat api: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}, "strict": ...}} # noqa: E501
# responses api: {"type": "function", "name": "...", "description": "...", "parameters": {...}, "strict": ...} # noqa: E501
if tool["type"] == "function" and "function" in tool:
new_tools.append({"type": "function", **tool["function"]})
flattened = {"type": "function", **tool["function"]}
for key in _OPENAI_EXTRA_FIELDS:
if key in tool:
flattened[key] = tool[key]
new_tools.append(flattened)
else:
if tool["type"] == "image_generation":
# Handle partial images (not yet supported)

View File

@@ -1267,3 +1267,60 @@ def test_csv_input() -> None:
"3" in str(response2.content).lower()
or "three" in str(response2.content).lower()
)
def test_tool_search() -> None:
    """Deferred function tools are discoverable via the tool_search built-in.

    Both function tools carry ``defer_loading`` in ``extras``; the model is
    expected to locate and call ``get_weather`` through tool search.
    """
    from langchain_core.tools import tool

    @tool(extras={"defer_loading": True})
    def get_weather(location: str) -> str:
        """Get the current weather for a location."""
        return f"Sunny in {location}"

    @tool(extras={"defer_loading": True})
    def get_population(city: str) -> str:
        """Get the population of a city."""
        return f"Population of {city}: 1,000,000"

    model = ChatOpenAI(model="gpt-4.1-mini")
    model_with_tools = model.bind_tools(
        [get_weather, get_population, {"type": "tool_search"}],
        parallel_tool_calls=False,
    )

    result = cast(
        AIMessage, model_with_tools.invoke("What's the weather in San Francisco?")
    )
    assert result.tool_calls
    assert result.tool_calls[0]["name"] == "get_weather"
def test_tool_search_with_namespace() -> None:
    """Tool search resolves a deferred tool that is grouped under a namespace."""
    # JSON schema for the single namespaced function tool.
    weather_params = {
        "type": "object",
        "properties": {
            "location": {"type": "string"},
        },
        "required": ["location"],
        "additionalProperties": False,
    }
    deferred_weather_tool = {
        "type": "function",
        "name": "get_weather",
        "description": "Get the current weather for a location.",
        "defer_loading": True,
        "parameters": weather_params,
    }
    weather_ns = {
        "type": "namespace",
        "name": "weather",
        "description": "Weather tools for looking up current conditions.",
        "tools": [deferred_weather_tool],
    }

    model = ChatOpenAI(model="gpt-4.1-mini")
    model_with_tools = model.bind_tools(
        [weather_ns, {"type": "tool_search"}],
        parallel_tool_calls=False,
    )

    result = cast(
        AIMessage, model_with_tools.invoke("What's the weather in San Francisco?")
    )
    assert result.tool_calls
    assert result.tool_calls[0]["name"] == "get_weather"

View File

@@ -1,5 +1,5 @@
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import Tool
from langchain_core.tools import Tool, tool
from langchain_openai import ChatOpenAI, custom_tool
@@ -96,6 +96,70 @@ def test_custom_tool() -> None:
assert payload["input"] == expected_input
def test_extras_with_defer_loading() -> None:
    """``defer_loading`` from BaseTool.extras reaches both tool payload shapes."""

    @tool(extras={"defer_loading": True})
    def get_weather(location: str) -> str:
        """Get the current weather for a location."""
        return f"Sunny in {location}"

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
    bound = llm.bind_tools([get_weather, {"type": "tool_search"}])
    bound_tools = bound.kwargs["tools"]  # type: ignore[attr-defined]

    # Chat Completions shape: the extras key sits beside the nested "function".
    chat_tool = next(t for t in bound_tools if t["type"] == "function")
    assert chat_tool["defer_loading"] is True
    assert chat_tool["function"]["name"] == "get_weather"
    assert any(t["type"] == "tool_search" for t in bound_tools)

    # Responses API shape: the function dict is flattened but keeps the flag.
    from langchain_openai.chat_models.base import _construct_responses_api_payload

    payload = _construct_responses_api_payload(
        [HumanMessage("hello")],
        {"model": "gpt-4.1", "stream": False, "tools": list(bound_tools)},
    )
    responses_tools = payload["tools"]
    responses_func = next(t for t in responses_tools if t["type"] == "function")
    assert responses_func["defer_loading"] is True
    assert responses_func["name"] == "get_weather"
    assert any(t["type"] == "tool_search" for t in responses_tools)
def test_tool_search_dict_passthrough() -> None:
    """A bare ``{"type": "tool_search"}`` dict survives bind_tools unchanged."""
    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
    bound = llm.bind_tools([{"type": "tool_search"}])
    bound_tool_types = [t["type"] for t in bound.kwargs["tools"]]  # type: ignore[attr-defined]
    assert "tool_search" in bound_tool_types
def test_namespace_dict_passthrough() -> None:
    """A namespace dict containing a deferred tool survives bind_tools unchanged."""
    # JSON schema for the namespaced CRM function tool.
    order_params = {
        "type": "object",
        "properties": {"customer_id": {"type": "string"}},
        "required": ["customer_id"],
        "additionalProperties": False,
    }
    ns = {
        "type": "namespace",
        "name": "crm",
        "description": "CRM tools.",
        "tools": [
            {
                "type": "function",
                "name": "list_orders",
                "description": "List orders.",
                "defer_loading": True,
                "parameters": order_params,
            }
        ],
    }

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True)
    bound = llm.bind_tools([ns, {"type": "tool_search"}])
    bound_tool_types = {t["type"] for t in bound.kwargs["tools"]}  # type: ignore[attr-defined]
    assert "namespace" in bound_tool_types
    assert "tool_search" in bound_tool_types
async def test_async_custom_tool() -> None:
@custom_tool
async def my_async_tool(x: str) -> str:

View File

@@ -610,7 +610,7 @@ typing = [
[[package]]
name = "langchain-core"
version = "1.2.16"
version = "1.2.17"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },