feat(openai): support tool search (#35582)

This commit is contained in:
ccurme
2026-03-08 08:53:13 -04:00
committed by GitHub
parent 532b014f5c
commit fbfe4b812d
13 changed files with 514 additions and 13 deletions

View File

@@ -731,6 +731,11 @@ def _convert_to_v1_from_responses(message: AIMessage) -> list[types.ContentBlock
tool_call_block["extras"]["item_id"] = block["id"]
if "index" in block:
tool_call_block["index"] = f"lc_tc_{block['index']}"
for extra_key in ("status", "namespace"):
if extra_key in block:
if "extras" not in tool_call_block:
tool_call_block["extras"] = {}
tool_call_block["extras"][extra_key] = block[extra_key]
yield tool_call_block
elif block_type == "web_search_call":
@@ -979,6 +984,51 @@ def _convert_to_v1_from_responses(message: AIMessage) -> list[types.ContentBlock
mcp_list_tools_result["index"] = f"lc_mltr_{block['index'] + 1}"
yield cast("types.ServerToolResult", mcp_list_tools_result)
elif (
block_type == "tool_search_call" and block.get("execution") == "server"
):
tool_search_call: dict[str, Any] = {
"type": "server_tool_call",
"name": "tool_search",
"id": block["id"],
"args": block.get("arguments", {}),
}
if "index" in block:
tool_search_call["index"] = f"lc_tsc_{block['index']}"
extras: dict[str, Any] = {}
known = {"type", "id", "arguments", "index"}
for key in block:
if key not in known:
extras[key] = block[key]
if extras:
tool_search_call["extras"] = extras
yield cast("types.ServerToolCall", tool_search_call)
elif (
block_type == "tool_search_output"
and block.get("execution") == "server"
):
tool_search_output: dict[str, Any] = {
"type": "server_tool_result",
"tool_call_id": block["id"],
"output": {"tools": block.get("tools", [])},
}
status = block.get("status")
if status == "failed":
tool_search_output["status"] = "error"
elif status == "completed":
tool_search_output["status"] = "success"
if "index" in block and isinstance(block["index"], int):
tool_search_output["index"] = f"lc_tso_{block['index']}"
extras_out: dict[str, Any] = {"name": "tool_search"}
known_out = {"type", "id", "status", "tools", "index"}
for key in block:
if key not in known_out:
extras_out[key] = block[key]
if extras_out:
tool_search_output["extras"] = extras_out
yield cast("types.ServerToolResult", tool_search_output)
elif block_type in types.KNOWN_BLOCK_TYPES:
yield cast("types.ContentBlock", block)
else:

View File

@@ -508,6 +508,8 @@ _WellKnownOpenAITools = (
"image_generation",
"web_search_preview",
"web_search",
"tool_search",
"namespace",
)

View File

@@ -103,6 +103,8 @@ def _convert_to_v03_ai_message(
"mcp_list_tools",
"mcp_approval_request",
"image_generation_call",
"tool_search_call",
"tool_search_output",
):
# Store built-in tool calls in additional_kwargs
if "tool_outputs" not in message.additional_kwargs:
@@ -420,17 +422,58 @@ def _convert_from_v1_to_responses(
new_block["name"] = block["name"]
if "extras" in block and "arguments" in block["extras"]:
new_block["arguments"] = block["extras"]["arguments"]
if any(key not in block for key in ("name", "arguments")):
if any(key not in new_block for key in ("name", "arguments")):
matching_tool_calls = [
call for call in tool_calls if call["id"] == block["id"]
]
if matching_tool_calls:
tool_call = matching_tool_calls[0]
if "name" not in block:
if "name" not in new_block:
new_block["name"] = tool_call["name"]
if "arguments" not in block:
new_block["arguments"] = json.dumps(tool_call["args"])
if "arguments" not in new_block:
new_block["arguments"] = json.dumps(
tool_call["args"], separators=(",", ":")
)
if "extras" in block:
for extra_key in ("status", "namespace"):
if extra_key in block["extras"]:
new_block[extra_key] = block["extras"][extra_key]
new_content.append(new_block)
elif block["type"] == "server_tool_call" and block.get("name") == "tool_search":
extras = block.get("extras", {})
new_block = {"id": block["id"]}
status = extras.get("status")
if status:
new_block["status"] = status
new_block["type"] = "tool_search_call"
if "args" in block:
new_block["arguments"] = block["args"]
execution = extras.get("execution")
if execution:
new_block["execution"] = execution
new_content.append(new_block)
elif (
block["type"] == "server_tool_result"
and block.get("extras", {}).get("name") == "tool_search"
):
extras = block.get("extras", {})
new_block = {"id": block.get("tool_call_id", "")}
status = block.get("status")
if status == "success":
new_block["status"] = "completed"
elif status == "error":
new_block["status"] = "failed"
elif status:
new_block["status"] = status
new_block["type"] = "tool_search_output"
new_block["execution"] = "server"
output: dict = block.get("output", {})
if isinstance(output, dict) and "tools" in output:
new_block["tools"] = output["tools"]
new_content.append(new_block)
elif (
is_data_content_block(cast(dict, block))
and block["type"] == "image"
@@ -441,7 +484,7 @@ def _convert_from_v1_to_responses(
new_block = {"type": "image_generation_call", "result": block["base64"]}
for extra_key in ("id", "status"):
if extra_key in block:
new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
new_block[extra_key] = block[extra_key] # type: ignore[literal-required]
elif extra_key in block.get("extras", {}):
new_block[extra_key] = block["extras"][extra_key]
new_content.append(new_block)

View File

@@ -166,6 +166,7 @@ WellKnownTools = (
"code_interpreter",
"mcp",
"image_generation",
"tool_search",
)
@@ -1984,6 +1985,14 @@ class BaseChatOpenAI(BaseChatModel):
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
for original, formatted in zip(tools, formatted_tools, strict=False):
if (
isinstance(original, BaseTool)
and hasattr(original, "extras")
and isinstance(original.extras, dict)
and "defer_loading" in original.extras
):
formatted["defer_loading"] = original.extras["defer_loading"]
tool_names = []
for tool in formatted_tools:
if "function" in tool:
@@ -3981,7 +3990,8 @@ def _construct_responses_api_payload(
# chat api: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}, "strict": ...}} # noqa: E501
# responses api: {"type": "function", "name": "...", "description": "...", "parameters": {...}, "strict": ...} # noqa: E501
if tool["type"] == "function" and "function" in tool:
new_tools.append({"type": "function", **tool["function"]})
extra = {k: v for k, v in tool.items() if k not in ("type", "function")}
new_tools.append({"type": "function", **tool["function"], **extra})
else:
if tool["type"] == "image_generation":
# Handle partial images (not yet supported)
@@ -4308,6 +4318,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
"mcp_call",
"mcp_list_tools",
"mcp_approval_request",
"tool_search_call",
"tool_search_output",
):
input_.append(_pop_index_and_sub_index(block))
elif block_type == "image_generation_call":
@@ -4353,7 +4365,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
elif msg["role"] in ("user", "system", "developer"):
if isinstance(msg["content"], list):
new_blocks = []
non_message_item_types = ("mcp_approval_response",)
non_message_item_types = ("mcp_approval_response", "tool_search_output")
for block in msg["content"]:
if block["type"] in ("text", "image_url", "file"):
new_blocks.append(
@@ -4510,6 +4522,8 @@ def _construct_lc_result_from_responses_api(
"mcp_list_tools",
"mcp_approval_request",
"image_generation_call",
"tool_search_call",
"tool_search_output",
):
content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
@@ -4719,6 +4733,8 @@ def _convert_responses_chunk_to_generation_chunk(
"mcp_list_tools",
"mcp_approval_request",
"image_generation_call",
"tool_search_call",
"tool_search_output",
):
_advance(chunk.output_index)
tool_output = chunk.item.model_dump(exclude_none=True, mode="json")

View File

@@ -1,3 +1,4 @@
import json
from typing import Any
import pytest
@@ -30,6 +31,9 @@ def remove_response_headers(response: dict) -> dict:
def vcr_config() -> dict:
"""Extend the default configuration coming from langchain_tests."""
config = base_vcr_config()
config["match_on"] = [
m if m != "body" else "json_body" for m in config.get("match_on", [])
]
config.setdefault("filter_headers", []).extend(_EXTRA_HEADERS)
config["before_record_request"] = remove_request_headers
config["before_record_response"] = remove_response_headers
@@ -38,6 +42,24 @@ def vcr_config() -> dict:
return config
def _json_body_matcher(r1: Any, r2: Any) -> None:
"""Match request bodies as parsed JSON, ignoring key order."""
b1 = r1.body or b""
b2 = r2.body or b""
if isinstance(b1, bytes):
b1 = b1.decode("utf-8")
if isinstance(b2, bytes):
b2 = b2.decode("utf-8")
try:
j1 = json.loads(b1)
j2 = json.loads(b2)
except (json.JSONDecodeError, ValueError):
assert b1 == b2, f"body mismatch (non-JSON):\n{b1}\n!=\n{b2}"
return
assert j1 == j2, f"body mismatch:\n{j1}\n!=\n{j2}"
# pytest-recording hook: runs once when VCR is set up for the test session.
def pytest_recording_configure(config: dict, vcr: VCR) -> None:
# Use the project's custom cassette persister.
vcr.register_persister(CustomPersister())
# Store cassettes gzip-compressed under the "yaml.gz" serializer name.
vcr.register_serializer("yaml.gz", CustomSerializer())
# Compare request bodies as parsed JSON (key-order insensitive) under "json_body".
vcr.register_matcher("json_body", _json_body_matcher)

View File

@@ -7,6 +7,13 @@ from typing import Annotated, Any, Literal, cast
import openai
import pytest
from langchain.agents import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ToolCallRequest,
hook_config,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -14,7 +21,10 @@ from langchain_core.messages import (
BaseMessageChunk,
HumanMessage,
MessageLikeRepresentation,
ToolMessage,
)
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel
from typing_extensions import TypedDict
@@ -193,6 +203,74 @@ def test_function_calling(output_version: Literal["v0", "responses/v1", "v1"]) -
_check_response(response)
# VCR-recorded integration test: one full tool-calling round trip
# (model emits a tool call -> tool runs locally -> result fed back -> final answer)
# against the Responses API, for both supported output versions.
@pytest.mark.default_cassette("test_agent_loop.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_agent_loop(output_version: Literal["responses/v1", "v1"]) -> None:
# Minimal stub tool with a canned answer.
@tool
def get_weather(location: str) -> str:
"""Get the weather for a location."""
return "It's sunny."
llm = ChatOpenAI(
model="gpt-5.4",
use_responses_api=True,
output_version=output_version,
)
llm_with_tools = llm.bind_tools([get_weather])
input_message = HumanMessage("What is the weather in San Francisco, CA?")
# First turn: the model should request exactly one tool call.
tool_call_message = llm_with_tools.invoke([input_message])
assert isinstance(tool_call_message, AIMessage)
tool_calls = tool_call_message.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
# Invoking a tool with a tool-call dict produces a ToolMessage.
tool_message = get_weather.invoke(tool_call)
assert isinstance(tool_message, ToolMessage)
# Second turn: feed the tool result back to get the final answer.
response = llm_with_tools.invoke(
[
input_message,
tool_call_message,
tool_message,
]
)
assert isinstance(response, AIMessage)
# Streaming variant of the agent-loop round trip: same flow as test_agent_loop
# but with streaming=True and reasoning summaries enabled.
@pytest.mark.default_cassette("test_agent_loop_streaming.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_agent_loop_streaming(output_version: Literal["responses/v1", "v1"]) -> None:
# Minimal stub tool with a canned answer.
@tool
def get_weather(location: str) -> str:
"""Get the weather for a location."""
return "It's sunny."
llm = ChatOpenAI(
model="gpt-5.2",
use_responses_api=True,
reasoning={"effort": "medium", "summary": "auto"},
streaming=True,
output_version=output_version,
)
llm_with_tools = llm.bind_tools([get_weather])
input_message = HumanMessage("What is the weather in San Francisco, CA?")
# First turn: the model should request exactly one tool call.
tool_call_message = llm_with_tools.invoke([input_message])
assert isinstance(tool_call_message, AIMessage)
tool_calls = tool_call_message.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
tool_message = get_weather.invoke(tool_call)
assert isinstance(tool_message, ToolMessage)
# Second turn: feed the tool result back to get the final answer.
response = llm_with_tools.invoke(
[
input_message,
tool_call_message,
tool_message,
]
)
assert isinstance(response, AIMessage)
class Foo(BaseModel):
response: str
@@ -1267,3 +1345,183 @@ def test_csv_input() -> None:
"3" in str(response2.content).lower()
or "three" in str(response2.content).lower()
)
# VCR-recorded test of the server-executed tool_search built-in: tools marked
# defer_loading are discovered via tool search before the model calls them.
@pytest.mark.default_cassette("test_tool_search.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_tool_search(output_version: str) -> None:
# Both tools are deferred, so the model must search for them first.
@tool(extras={"defer_loading": True})
def get_weather(location: str) -> str:
"""Get the current weather for a location."""
return f"The weather in {location} is sunny and 72°F"
@tool(extras={"defer_loading": True})
def get_recipe(query: str) -> None:
"""Get a recipe for chicken soup."""
model = ChatOpenAI(
model="gpt-5.4",
use_responses_api=True,
output_version=output_version,
)
agent = create_agent(
model=model,
tools=[get_weather, get_recipe, {"type": "tool_search"}],
)
input_message = {"role": "user", "content": "What's the weather in San Francisco?"}
result = agent.invoke({"messages": [input_message]})
# Expected transcript: human, AI (search + tool call), tool result, final AI.
assert len(result["messages"]) == 4
tool_call_message = result["messages"][1]
assert isinstance(tool_call_message, AIMessage)
assert tool_call_message.tool_calls
# v1 normalizes built-in calls to server_tool_call/server_tool_result blocks;
# responses/v1 keeps the raw Responses API block types.
if output_version == "v1":
assert [block["type"] for block in tool_call_message.content] == [  # type: ignore[index]
"server_tool_call",
"server_tool_result",
"tool_call",
]
else:
assert [block["type"] for block in tool_call_message.content] == [  # type: ignore[index]
"tool_search_call",
"tool_search_output",
"function_call",
]
assert isinstance(result["messages"][2], ToolMessage)
assert result["messages"][3].text
# Streaming variant of test_tool_search: same deferred-tool discovery flow
# with streaming=True, asserting identical block types per output version.
@pytest.mark.default_cassette("test_tool_search_streaming.yaml.gz")
@pytest.mark.vcr
@pytest.mark.parametrize("output_version", ["responses/v1", "v1"])
def test_tool_search_streaming(output_version: str) -> None:
# Both tools are deferred, so the model must search for them first.
@tool(extras={"defer_loading": True})
def get_weather(location: str) -> str:
"""Get the current weather for a location."""
return f"The weather in {location} is sunny and 72°F"
@tool(extras={"defer_loading": True})
def get_recipe(query: str) -> None:
"""Get a recipe for chicken soup."""
model = ChatOpenAI(
model="gpt-5.4",
use_responses_api=True,
streaming=True,
output_version=output_version,
)
agent = create_agent(
model=model,
tools=[get_weather, get_recipe, {"type": "tool_search"}],
)
input_message = {"role": "user", "content": "What's the weather in San Francisco?"}
result = agent.invoke({"messages": [input_message]})
# Expected transcript: human, AI (search + tool call), tool result, final AI.
assert len(result["messages"]) == 4
tool_call_message = result["messages"][1]
assert isinstance(tool_call_message, AIMessage)
assert tool_call_message.tool_calls
# v1 normalizes built-in calls; responses/v1 keeps raw Responses API types.
if output_version == "v1":
assert [block["type"] for block in tool_call_message.content] == [  # type: ignore[index]
"server_tool_call",
"server_tool_result",
"tool_call",
]
else:
assert [block["type"] for block in tool_call_message.content] == [  # type: ignore[index]
"tool_search_call",
"tool_search_output",
"function_call",
]
assert isinstance(result["messages"][2], ToolMessage)
assert result["messages"][3].text
# Client-executed tool search: the tool_search built-in is configured with
# execution="client", so OUR middleware (not the server) answers the search
# call and loads the matching tool definitions back into the conversation.
@pytest.mark.vcr
def test_client_executed_tool_search() -> None:
@tool
def get_weather(location: str) -> str:
"""Get the current weather for a location."""
return f"The weather in {location} is sunny and 72°F"
# Client-side search implementation: returns deferred tool definitions.
def search_tools(goal: str) -> list[dict]:
"""Search for available tools to help answer the question."""
return [
{
"type": "function",
"defer_loading": True,
**convert_to_openai_tool(get_weather)["function"],
}
]
# Advertise tool_search with the schema of our local search function.
tool_search_schema = convert_to_openai_tool(search_tools, strict=True)
tool_search_config: dict = {
"type": "tool_search",
"execution": "client",
"description": tool_search_schema["function"]["description"],
"parameters": tool_search_schema["function"]["parameters"],
}
class ClientToolSearchMiddleware(AgentMiddleware):
# After each model turn, answer any tool_search_call block ourselves and
# jump back to the model with the loaded tool definitions.
@hook_config(can_jump_to=["model"])
def after_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
last_message = state["messages"][-1]
if not isinstance(last_message, AIMessage):
return None
for block in last_message.content:
if isinstance(block, dict) and block.get("type") == "tool_search_call":
call_id = block.get("call_id")
args = block.get("arguments", {})
goal = args.get("goal", "") if isinstance(args, dict) else ""
loaded_tools = search_tools(goal)
tool_search_output = {
"type": "tool_search_output",
"execution": "client",
"call_id": call_id,
"status": "completed",
"tools": loaded_tools,
}
# Feed the search result back and re-run the model.
return {
"messages": [HumanMessage(content=[tool_search_output])],
"jump_to": "model",
}
return None
# Route calls to the dynamically-loaded tool onto the local implementation.
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Any,
) -> Any:
if request.tool_call["name"] == "get_weather":
return handler(request.override(tool=get_weather))
return handler(request)
llm = ChatOpenAI(model="gpt-5.4", use_responses_api=True)
agent = create_agent(
model=llm,
tools=[tool_search_config],
middleware=[ClientToolSearchMiddleware()],
)
result = agent.invoke(
{"messages": [HumanMessage("What's the weather in San Francisco?")]}
)
# Expected transcript: human, search call, search output, tool call,
# tool result, final answer.
messages = result["messages"]
search_tool_call = messages[1]
assert search_tool_call.content[0]["type"] == "tool_search_call"
search_tool_output = messages[2]
assert search_tool_output.content[0]["type"] == "tool_search_output"
tool_call = messages[3]
assert tool_call.tool_calls
assert isinstance(messages[4], ToolMessage)
assert messages[5].text

View File

@@ -2787,13 +2787,13 @@ def test_convert_from_v1_to_chat_completions(
"type": "function_call",
"call_id": "call_123",
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
"arguments": '{"location":"San Francisco"}',
},
{
"type": "function_call",
"call_id": "call_234",
"name": "get_weather_2",
"arguments": '{"location": "New York"}',
"arguments": '{"location":"New York"}',
"id": "fc_123",
},
{"type": "text", "text": "Hello "},
@@ -3474,3 +3474,113 @@ def test_context_overflow_error_backwards_compatibility() -> None:
# Verify it's both types (multiple inheritance)
assert isinstance(exc_info.value, openai.BadRequestError)
assert isinstance(exc_info.value, ContextOverflowError)
def test_tool_search_passthrough() -> None:
    """Built-in ``tool_search`` dicts survive ``bind_tools`` unchanged."""
    model = ChatOpenAI(model="gpt-4o")
    bound = model.bind_tools([{"type": "tool_search"}])
    payload = bound._get_request_payload(  # type: ignore[attr-defined]
        "test",
        **bound.kwargs,  # type: ignore[attr-defined]
    )
    # The built-in tool must be passed through verbatim alongside the input.
    assert {"type": "tool_search"} in payload["tools"]
    assert "input" in payload
def test_tool_search_with_defer_loading_extras() -> None:
    """``defer_loading`` set via BaseTool extras is merged into the tool def."""
    from langchain_core.tools import tool

    @tool(extras={"defer_loading": True})
    def get_weather(location: str) -> str:
        """Get weather for a location."""
        return f"Weather in {location}"

    llm = ChatOpenAI(model="gpt-4o")
    bound = llm.bind_tools([get_weather, {"type": "tool_search"}])
    payload = bound._get_request_payload(  # type: ignore[attr-defined]
        "test",
        **bound.kwargs,  # type: ignore[attr-defined]
    )
    # Locate the converted function tool in the request payload.
    weather_tool = next(
        (
            t
            for t in payload["tools"]
            if t.get("type") == "function" and t.get("name") == "get_weather"
        ),
        None,
    )
    assert weather_tool is not None
    assert weather_tool["defer_loading"] is True
    assert {"type": "tool_search"} in payload["tools"]
def test_namespace_passthrough() -> None:
    """Namespace tool dicts are forwarded to the payload untouched."""
    llm = ChatOpenAI(model="gpt-4o")
    namespace_tool = {
        "type": "namespace",
        "name": "crm",
        "description": "CRM tools.",
        "tools": [
            {
                "type": "function",
                "name": "list_orders",
                "description": "List orders.",
                "defer_loading": True,
                "parameters": {
                    "type": "object",
                    "properties": {"customer_id": {"type": "string"}},
                    "required": ["customer_id"],
                },
            }
        ],
    }
    bound = llm.bind_tools([namespace_tool, {"type": "tool_search"}])
    payload = bound._get_request_payload(  # type: ignore[attr-defined]
        "test",
        **bound.kwargs,  # type: ignore[attr-defined]
    )
    # The namespace tool should appear in the payload exactly as provided.
    ns = next((t for t in payload["tools"] if t.get("type") == "namespace"), None)
    assert ns is not None
    assert ns["name"] == "crm"
    assert ns["tools"][0]["defer_loading"] is True
    assert {"type": "tool_search"} in payload["tools"]
def test_defer_loading_in_responses_api_payload() -> None:
    """``defer_loading`` survives conversion to the Responses API tool shape."""
    from langchain_openai.chat_models.base import _construct_responses_api_payload

    messages: list = []
    payload = {
        "model": "gpt-4o",
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get weather.",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                    },
                },
                "defer_loading": True,
            },
            {"type": "tool_search"},
        ],
    }
    result = _construct_responses_api_payload(messages, payload)
    # Find the flattened function tool in the converted payload.
    weather_tool = next(
        (t for t in result["tools"] if t.get("name") == "get_weather"), None
    )
    assert weather_tool is not None
    assert weather_tool["defer_loading"] is True
    assert weather_tool["type"] == "function"
    assert {"type": "tool_search"} in result["tools"]

View File

@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -1105,7 +1105,7 @@ wheels = [
[[package]]
name = "openai"
version = "2.21.0"
version = "2.26.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -1117,9 +1117,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/92/e5/3d197a0947a166649f566706d7a4c8f7fe38f1fa7b24c9bcffe4c7591d44/openai-2.21.0.tar.gz", hash = "sha256:81b48ce4b8bbb2cc3af02047ceb19561f7b1dc0d4e52d1de7f02abfd15aa59b7", size = 644374, upload-time = "2026-02-14T00:12:01.577Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d7/91/2a06c4e9597c338cac1e5e5a8dd6f29e1836fc229c4c523529dca387fda8/openai-2.26.0.tar.gz", hash = "sha256:b41f37c140ae0034a6e92b0c509376d907f3a66109935fba2c1b471a7c05a8fb", size = 666702, upload-time = "2026-03-05T23:17:35.874Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cc/56/0a89092a453bb2c676d66abee44f863e742b2110d4dbb1dbcca3f7e5fc33/openai-2.21.0-py3-none-any.whl", hash = "sha256:0bc1c775e5b1536c294eded39ee08f8407656537ccc71b1004104fe1602e267c", size = 1103065, upload-time = "2026-02-14T00:11:59.603Z" },
{ url = "https://files.pythonhosted.org/packages/c6/2e/3f73e8ca53718952222cacd0cf7eecc9db439d020f0c1fe7ae717e4e199a/openai-2.26.0-py3-none-any.whl", hash = "sha256:6151bf8f83802f036117f06cc8a57b3a4da60da9926826cc96747888b57f394f", size = 1136409, upload-time = "2026-03-05T23:17:34.072Z" },
]
[[package]]