feat(openai): support tool search (#35582)

This commit is contained in:
ccurme
2026-03-08 08:53:13 -04:00
committed by GitHub
parent 532b014f5c
commit fbfe4b812d
13 changed files with 514 additions and 13 deletions

View File

@@ -103,6 +103,8 @@ def _convert_to_v03_ai_message(
"mcp_list_tools",
"mcp_approval_request",
"image_generation_call",
"tool_search_call",
"tool_search_output",
):
# Store built-in tool calls in additional_kwargs
if "tool_outputs" not in message.additional_kwargs:
@@ -420,17 +422,58 @@ def _convert_from_v1_to_responses(
new_block["name"] = block["name"]
if "extras" in block and "arguments" in block["extras"]:
new_block["arguments"] = block["extras"]["arguments"]
if any(key not in block for key in ("name", "arguments")):
if any(key not in new_block for key in ("name", "arguments")):
matching_tool_calls = [
call for call in tool_calls if call["id"] == block["id"]
]
if matching_tool_calls:
tool_call = matching_tool_calls[0]
if "name" not in block:
if "name" not in new_block:
new_block["name"] = tool_call["name"]
if "arguments" not in block:
new_block["arguments"] = json.dumps(tool_call["args"])
if "arguments" not in new_block:
new_block["arguments"] = json.dumps(
tool_call["args"], separators=(",", ":")
)
if "extras" in block:
for extra_key in ("status", "namespace"):
if extra_key in block["extras"]:
new_block[extra_key] = block["extras"][extra_key]
new_content.append(new_block)
elif block["type"] == "server_tool_call" and block.get("name") == "tool_search":
extras = block.get("extras", {})
new_block = {"id": block["id"]}
status = extras.get("status")
if status:
new_block["status"] = status
new_block["type"] = "tool_search_call"
if "args" in block:
new_block["arguments"] = block["args"]
execution = extras.get("execution")
if execution:
new_block["execution"] = execution
new_content.append(new_block)
elif (
block["type"] == "server_tool_result"
and block.get("extras", {}).get("name") == "tool_search"
):
extras = block.get("extras", {})
new_block = {"id": block.get("tool_call_id", "")}
status = block.get("status")
if status == "success":
new_block["status"] = "completed"
elif status == "error":
new_block["status"] = "failed"
elif status:
new_block["status"] = status
new_block["type"] = "tool_search_output"
new_block["execution"] = "server"
output: dict = block.get("output", {})
if isinstance(output, dict) and "tools" in output:
new_block["tools"] = output["tools"]
new_content.append(new_block)
elif (
is_data_content_block(cast(dict, block))
and block["type"] == "image"
@@ -441,7 +484,7 @@ def _convert_from_v1_to_responses(
new_block = {"type": "image_generation_call", "result": block["base64"]}
for extra_key in ("id", "status"):
if extra_key in block:
new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
new_block[extra_key] = block[extra_key] # type: ignore[literal-required]
elif extra_key in block.get("extras", {}):
new_block[extra_key] = block["extras"][extra_key]
new_content.append(new_block)

View File

@@ -166,6 +166,7 @@ WellKnownTools = (
"code_interpreter",
"mcp",
"image_generation",
"tool_search",
)
@@ -1984,6 +1985,14 @@ class BaseChatOpenAI(BaseChatModel):
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
for original, formatted in zip(tools, formatted_tools, strict=False):
if (
isinstance(original, BaseTool)
and hasattr(original, "extras")
and isinstance(original.extras, dict)
and "defer_loading" in original.extras
):
formatted["defer_loading"] = original.extras["defer_loading"]
tool_names = []
for tool in formatted_tools:
if "function" in tool:
@@ -3981,7 +3990,8 @@ def _construct_responses_api_payload(
# chat api: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}, "strict": ...}} # noqa: E501
# responses api: {"type": "function", "name": "...", "description": "...", "parameters": {...}, "strict": ...} # noqa: E501
if tool["type"] == "function" and "function" in tool:
new_tools.append({"type": "function", **tool["function"]})
extra = {k: v for k, v in tool.items() if k not in ("type", "function")}
new_tools.append({"type": "function", **tool["function"], **extra})
else:
if tool["type"] == "image_generation":
# Handle partial images (not yet supported)
@@ -4308,6 +4318,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
"mcp_call",
"mcp_list_tools",
"mcp_approval_request",
"tool_search_call",
"tool_search_output",
):
input_.append(_pop_index_and_sub_index(block))
elif block_type == "image_generation_call":
@@ -4353,7 +4365,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
elif msg["role"] in ("user", "system", "developer"):
if isinstance(msg["content"], list):
new_blocks = []
non_message_item_types = ("mcp_approval_response",)
non_message_item_types = ("mcp_approval_response", "tool_search_output")
for block in msg["content"]:
if block["type"] in ("text", "image_url", "file"):
new_blocks.append(
@@ -4510,6 +4522,8 @@ def _construct_lc_result_from_responses_api(
"mcp_list_tools",
"mcp_approval_request",
"image_generation_call",
"tool_search_call",
"tool_search_output",
):
content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
@@ -4719,6 +4733,8 @@ def _convert_responses_chunk_to_generation_chunk(
"mcp_list_tools",
"mcp_approval_request",
"image_generation_call",
"tool_search_call",
"tool_search_output",
):
_advance(chunk.output_index)
tool_output = chunk.item.model_dump(exclude_none=True, mode="json")