mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(openai): support tool search (#35582)
This commit is contained in:
@@ -103,6 +103,8 @@ def _convert_to_v03_ai_message(
|
||||
"mcp_list_tools",
|
||||
"mcp_approval_request",
|
||||
"image_generation_call",
|
||||
"tool_search_call",
|
||||
"tool_search_output",
|
||||
):
|
||||
# Store built-in tool calls in additional_kwargs
|
||||
if "tool_outputs" not in message.additional_kwargs:
|
||||
@@ -420,17 +422,58 @@ def _convert_from_v1_to_responses(
|
||||
new_block["name"] = block["name"]
|
||||
if "extras" in block and "arguments" in block["extras"]:
|
||||
new_block["arguments"] = block["extras"]["arguments"]
|
||||
if any(key not in block for key in ("name", "arguments")):
|
||||
if any(key not in new_block for key in ("name", "arguments")):
|
||||
matching_tool_calls = [
|
||||
call for call in tool_calls if call["id"] == block["id"]
|
||||
]
|
||||
if matching_tool_calls:
|
||||
tool_call = matching_tool_calls[0]
|
||||
if "name" not in block:
|
||||
if "name" not in new_block:
|
||||
new_block["name"] = tool_call["name"]
|
||||
if "arguments" not in block:
|
||||
new_block["arguments"] = json.dumps(tool_call["args"])
|
||||
if "arguments" not in new_block:
|
||||
new_block["arguments"] = json.dumps(
|
||||
tool_call["args"], separators=(",", ":")
|
||||
)
|
||||
if "extras" in block:
|
||||
for extra_key in ("status", "namespace"):
|
||||
if extra_key in block["extras"]:
|
||||
new_block[extra_key] = block["extras"][extra_key]
|
||||
new_content.append(new_block)
|
||||
|
||||
elif block["type"] == "server_tool_call" and block.get("name") == "tool_search":
|
||||
extras = block.get("extras", {})
|
||||
new_block = {"id": block["id"]}
|
||||
status = extras.get("status")
|
||||
if status:
|
||||
new_block["status"] = status
|
||||
new_block["type"] = "tool_search_call"
|
||||
if "args" in block:
|
||||
new_block["arguments"] = block["args"]
|
||||
execution = extras.get("execution")
|
||||
if execution:
|
||||
new_block["execution"] = execution
|
||||
new_content.append(new_block)
|
||||
|
||||
elif (
|
||||
block["type"] == "server_tool_result"
|
||||
and block.get("extras", {}).get("name") == "tool_search"
|
||||
):
|
||||
extras = block.get("extras", {})
|
||||
new_block = {"id": block.get("tool_call_id", "")}
|
||||
status = block.get("status")
|
||||
if status == "success":
|
||||
new_block["status"] = "completed"
|
||||
elif status == "error":
|
||||
new_block["status"] = "failed"
|
||||
elif status:
|
||||
new_block["status"] = status
|
||||
new_block["type"] = "tool_search_output"
|
||||
new_block["execution"] = "server"
|
||||
output: dict = block.get("output", {})
|
||||
if isinstance(output, dict) and "tools" in output:
|
||||
new_block["tools"] = output["tools"]
|
||||
new_content.append(new_block)
|
||||
|
||||
elif (
|
||||
is_data_content_block(cast(dict, block))
|
||||
and block["type"] == "image"
|
||||
@@ -441,7 +484,7 @@ def _convert_from_v1_to_responses(
|
||||
new_block = {"type": "image_generation_call", "result": block["base64"]}
|
||||
for extra_key in ("id", "status"):
|
||||
if extra_key in block:
|
||||
new_block[extra_key] = block[extra_key] # type: ignore[typeddict-item]
|
||||
new_block[extra_key] = block[extra_key] # type: ignore[literal-required]
|
||||
elif extra_key in block.get("extras", {}):
|
||||
new_block[extra_key] = block["extras"][extra_key]
|
||||
new_content.append(new_block)
|
||||
|
||||
@@ -166,6 +166,7 @@ WellKnownTools = (
|
||||
"code_interpreter",
|
||||
"mcp",
|
||||
"image_generation",
|
||||
"tool_search",
|
||||
)
|
||||
|
||||
|
||||
@@ -1984,6 +1985,14 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
formatted_tools = [
|
||||
convert_to_openai_tool(tool, strict=strict) for tool in tools
|
||||
]
|
||||
for original, formatted in zip(tools, formatted_tools, strict=False):
|
||||
if (
|
||||
isinstance(original, BaseTool)
|
||||
and hasattr(original, "extras")
|
||||
and isinstance(original.extras, dict)
|
||||
and "defer_loading" in original.extras
|
||||
):
|
||||
formatted["defer_loading"] = original.extras["defer_loading"]
|
||||
tool_names = []
|
||||
for tool in formatted_tools:
|
||||
if "function" in tool:
|
||||
@@ -3981,7 +3990,8 @@ def _construct_responses_api_payload(
|
||||
# chat api: {"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}, "strict": ...}} # noqa: E501
|
||||
# responses api: {"type": "function", "name": "...", "description": "...", "parameters": {...}, "strict": ...} # noqa: E501
|
||||
if tool["type"] == "function" and "function" in tool:
|
||||
new_tools.append({"type": "function", **tool["function"]})
|
||||
extra = {k: v for k, v in tool.items() if k not in ("type", "function")}
|
||||
new_tools.append({"type": "function", **tool["function"], **extra})
|
||||
else:
|
||||
if tool["type"] == "image_generation":
|
||||
# Handle partial images (not yet supported)
|
||||
@@ -4308,6 +4318,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
"mcp_call",
|
||||
"mcp_list_tools",
|
||||
"mcp_approval_request",
|
||||
"tool_search_call",
|
||||
"tool_search_output",
|
||||
):
|
||||
input_.append(_pop_index_and_sub_index(block))
|
||||
elif block_type == "image_generation_call":
|
||||
@@ -4353,7 +4365,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
|
||||
elif msg["role"] in ("user", "system", "developer"):
|
||||
if isinstance(msg["content"], list):
|
||||
new_blocks = []
|
||||
non_message_item_types = ("mcp_approval_response",)
|
||||
non_message_item_types = ("mcp_approval_response", "tool_search_output")
|
||||
for block in msg["content"]:
|
||||
if block["type"] in ("text", "image_url", "file"):
|
||||
new_blocks.append(
|
||||
@@ -4510,6 +4522,8 @@ def _construct_lc_result_from_responses_api(
|
||||
"mcp_list_tools",
|
||||
"mcp_approval_request",
|
||||
"image_generation_call",
|
||||
"tool_search_call",
|
||||
"tool_search_output",
|
||||
):
|
||||
content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
|
||||
|
||||
@@ -4719,6 +4733,8 @@ def _convert_responses_chunk_to_generation_chunk(
|
||||
"mcp_list_tools",
|
||||
"mcp_approval_request",
|
||||
"image_generation_call",
|
||||
"tool_search_call",
|
||||
"tool_search_output",
|
||||
):
|
||||
_advance(chunk.output_index)
|
||||
tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
Reference in New Issue
Block a user