feat(openai): custom tools (#32449)

This commit is contained in:
ccurme
2025-08-07 17:30:01 -03:00
committed by GitHub
parent 145d38f7dd
commit ec2b34a02d
11 changed files with 488 additions and 4 deletions

View File

@@ -1,6 +1,7 @@
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.tools import custom_tool
__all__ = [
"OpenAI",
@@ -9,4 +10,5 @@ __all__ = [
"AzureOpenAI",
"AzureChatOpenAI",
"AzureOpenAIEmbeddings",
"custom_tool",
]

View File

@@ -3582,6 +3582,20 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
return computer_call_output
def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict]:
custom_tool_output = None
for block in message.content:
if isinstance(block, dict) and block.get("type") == "custom_tool_call_output":
custom_tool_output = {
"type": "custom_tool_call_output",
"call_id": message.tool_call_id,
"output": block.get("output") or "",
}
break
return custom_tool_output
def _pop_index_and_sub_index(block: dict) -> dict:
"""When streaming, langchain-core uses the ``index`` key to aggregate
text blocks. OpenAI API does not support this key, so we need to remove it.
@@ -3608,7 +3622,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
msg.pop("name")
if msg["role"] == "tool":
tool_output = msg["content"]
if lc_msg.additional_kwargs.get("type") == "computer_call_output":
custom_tool_output = _make_custom_tool_output_from_message(lc_msg) # type: ignore[arg-type]
if custom_tool_output:
input_.append(custom_tool_output)
elif lc_msg.additional_kwargs.get("type") == "computer_call_output":
computer_call_output = _make_computer_call_output_from_message(
cast(ToolMessage, lc_msg)
)
@@ -3663,6 +3680,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
"file_search_call",
"function_call",
"computer_call",
"custom_tool_call",
"code_interpreter_call",
"mcp_call",
"mcp_list_tools",
@@ -3690,7 +3708,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
content_call_ids = {
block["call_id"]
for block in input_
if block.get("type") == "function_call" and "call_id" in block
if block.get("type") in ("function_call", "custom_tool_call")
and "call_id" in block
}
for tool_call in tool_calls:
if tool_call["id"] not in content_call_ids:
@@ -3841,6 +3860,15 @@ def _construct_lc_result_from_responses_api(
"error": error,
}
invalid_tool_calls.append(tool_call)
elif output.type == "custom_tool_call":
content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
tool_call = {
"type": "tool_call",
"name": output.name,
"args": {"__arg1": output.input},
"id": output.call_id,
}
tool_calls.append(tool_call)
elif output.type in (
"reasoning",
"web_search_call",
@@ -4044,6 +4072,23 @@ def _convert_responses_chunk_to_generation_chunk(
tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
tool_output["index"] = current_index
content.append(tool_output)
elif (
chunk.type == "response.output_item.done"
and chunk.item.type == "custom_tool_call"
):
_advance(chunk.output_index)
tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
tool_output["index"] = current_index
content.append(tool_output)
tool_call_chunks.append(
{
"type": "tool_call_chunk",
"name": chunk.item.name,
"args": json.dumps({"__arg1": chunk.item.input}),
"id": chunk.item.call_id,
"index": current_index,
}
)
elif chunk.type == "response.function_call_arguments.delta":
_advance(chunk.output_index)
tool_call_chunks.append(

View File

@@ -0,0 +1,3 @@
from langchain_openai.tools.custom_tool import custom_tool
__all__ = ["custom_tool"]

View File

@@ -0,0 +1,109 @@
import inspect
from collections.abc import Awaitable
from typing import Any, Callable
from langchain_core.tools import tool
def _make_wrapped_func(func: Callable[..., str]) -> Callable[..., list[dict[str, Any]]]:
def wrapped(x: str) -> list[dict[str, Any]]:
return [{"type": "custom_tool_call_output", "output": func(x)}]
return wrapped
def _make_wrapped_coroutine(
coroutine: Callable[..., Awaitable[str]],
) -> Callable[..., Awaitable[list[dict[str, Any]]]]:
async def wrapped(*args: Any, **kwargs: Any) -> list[dict[str, Any]]:
result = await coroutine(*args, **kwargs)
return [{"type": "custom_tool_call_output", "output": result}]
return wrapped
def custom_tool(*args: Any, **kwargs: Any) -> Any:
"""Decorator to create an OpenAI custom tool.
Custom tools allow for tools with (potentially long) freeform string inputs.
See below for an example using LangGraph:
.. code-block:: python
@custom_tool
def execute_code(code: str) -> str:
\"\"\"Execute python code.\"\"\"
return "27"
llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")
agent = create_react_agent(llm, [execute_code])
input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
for step in agent.stream(
{"messages": [input_message]},
stream_mode="values",
):
step["messages"][-1].pretty_print()
You can also specify a format for a corresponding context-free grammar using the
``format`` kwarg:
.. code-block:: python
from langchain_openai import ChatOpenAI, custom_tool
from langgraph.prebuilt import create_react_agent
grammar = \"\"\"
start: expr
expr: term (SP ADD SP term)* -> add
| term
term: factor (SP MUL SP factor)* -> mul
| factor
factor: INT
SP: " "
ADD: "+"
MUL: "*"
%import common.INT
\"\"\"
format = {"type": "grammar", "syntax": "lark", "definition": grammar}
# highlight-next-line
@custom_tool(format=format)
def do_math(input_string: str) -> str:
\"\"\"Do a mathematical operation.\"\"\"
return "27"
llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")
agent = create_react_agent(llm, [do_math])
input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
for step in agent.stream(
{"messages": [input_message]},
stream_mode="values",
):
step["messages"][-1].pretty_print()
"""
def decorator(func: Callable[..., Any]) -> Any:
metadata = {"type": "custom_tool"}
if "format" in kwargs:
metadata["format"] = kwargs.pop("format")
tool_obj = tool(infer_schema=False, **kwargs)(func)
tool_obj.metadata = metadata
tool_obj.description = func.__doc__
if inspect.iscoroutinefunction(func):
tool_obj.coroutine = _make_wrapped_coroutine(func)
else:
tool_obj.func = _make_wrapped_func(func)
return tool_obj
if args and callable(args[0]) and not kwargs:
return decorator(args[0])
return decorator