feat(openai): custom tools (#32449)

ccurme 2025-08-07 17:30:01 -03:00 committed by GitHub
parent 145d38f7dd
commit ec2b34a02d
11 changed files with 488 additions and 4 deletions

@@ -447,6 +447,163 @@
")"
]
},
{
"cell_type": "markdown",
"id": "c5d9d19d-8ab1-4d9d-b3a0-56ee4e89c528",
"metadata": {},
"source": [
"### Custom tools\n",
"\n",
":::info Requires ``langchain-openai>=0.3.29``\n",
"\n",
":::\n",
"\n",
"[Custom tools](https://platform.openai.com/docs/guides/function-calling#custom-tools) support tools with arbitrary string inputs. They can be particularly useful when you expect your string arguments to be long or complex."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a47c809b-852f-46bd-8b9e-d9534c17213d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_6894ff5747c0819d9b02fc5645b0be9c000169fd9fb68d99', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_7SYwMSQPbbEqFcKlKOpXeEux', 'input': 'print(3**3)', 'name': 'execute_code', 'type': 'custom_tool_call', 'id': 'ctc_6894ff5b9f54819d8155a63638d34103000169fd9fb68d99', 'status': 'completed'}]\n",
"Tool Calls:\n",
" execute_code (call_7SYwMSQPbbEqFcKlKOpXeEux)\n",
" Call ID: call_7SYwMSQPbbEqFcKlKOpXeEux\n",
" Args:\n",
" __arg1: print(3**3)\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: execute_code\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6894ff5db3b8819d9159b3a370a25843000169fd9fb68d99'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"\n",
"@custom_tool\n",
"def execute_code(code: str) -> str:\n",
" \"\"\"Execute python code.\"\"\"\n",
" return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [execute_code])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
" {\"messages\": [input_message]},\n",
" stream_mode=\"values\",\n",
"):\n",
" step[\"messages\"][-1].pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "5ef93be6-6d4c-4eea-acfd-248774074082",
"metadata": {},
"source": [
"<details>\n",
"<summary>Context-free grammars</summary>\n",
"\n",
"OpenAI supports the specification of a [context-free grammar](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for custom tool inputs in `lark` or `regex` format. See [OpenAI docs](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for details. The `format` parameter can be passed into `@custom_tool` as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2ae04586-be33-49c6-8947-7867801d868f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_689500828a8481a297ff0f98e328689c0681550c89797f43', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_jzH01RVhu6EFz7yUrOFXX55s', 'input': '3 * 3 * 3', 'name': 'do_math', 'type': 'custom_tool_call', 'id': 'ctc_6895008d57bc81a2b84d0993517a66b90681550c89797f43', 'status': 'completed'}]\n",
"Tool Calls:\n",
" do_math (call_jzH01RVhu6EFz7yUrOFXX55s)\n",
" Call ID: call_jzH01RVhu6EFz7yUrOFXX55s\n",
" Args:\n",
" __arg1: 3 * 3 * 3\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: do_math\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6895009776b881a2a25f0be8507d08f20681550c89797f43'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"grammar = \"\"\"\n",
"start: expr\n",
"expr: term (SP ADD SP term)* -> add\n",
"| term\n",
"term: factor (SP MUL SP factor)* -> mul\n",
"| factor\n",
"factor: INT\n",
"SP: \" \"\n",
"ADD: \"+\"\n",
"MUL: \"*\"\n",
"%import common.INT\n",
"\"\"\"\n",
"\n",
"format_ = {\"type\": \"grammar\", \"syntax\": \"lark\", \"definition\": grammar}\n",
"\n",
"\n",
"# highlight-next-line\n",
"@custom_tool(format=format_)\n",
"def do_math(input_string: str) -> str:\n",
" \"\"\"Do a mathematical operation.\"\"\"\n",
" return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [do_math])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
" {\"messages\": [input_message]},\n",
" stream_mode=\"values\",\n",
"):\n",
" step[\"messages\"][-1].pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "c63430c9-c7b0-4e92-a491-3f165dddeb8f",
"metadata": {},
"source": [
"</details>"
]
},
{
"cell_type": "markdown",
"id": "84833dd0-17e9-4269-82ed-550639d65751",

@@ -74,7 +74,14 @@ if TYPE_CHECKING:
    from collections.abc import Sequence

FILTERED_ARGS = ("run_manager", "callbacks")
-TOOL_MESSAGE_BLOCK_TYPES = ("text", "image_url", "image", "json", "search_result")
+TOOL_MESSAGE_BLOCK_TYPES = (
+    "text",
+    "image_url",
+    "image",
+    "json",
+    "search_result",
+    "custom_tool_call_output",
+)


class SchemaAnnotationError(TypeError):
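
Why this tuple matters (my reading of the surrounding langchain-core code, not spelled out in the diff): a tool that returns a list of dicts keeps that list as structured `ToolMessage` content only when every block's `type` is recognized; otherwise the output is stringified. A hedged sketch of the gate, with `is_recognized_block` as a stand-in name:

TOOL_MESSAGE_BLOCK_TYPES = (
    "text",
    "image_url",
    "image",
    "json",
    "search_result",
    "custom_tool_call_output",
)


def is_recognized_block(block: object) -> bool:
    # Stand-in for the internal check: list-of-dicts tool output stays
    # structured only if each block's type is recognized.
    return isinstance(block, dict) and block.get("type") in TOOL_MESSAGE_BLOCK_TYPES


# The wrapped custom-tool output now passes the gate instead of being
# flattened to a string:
assert all(
    is_recognized_block(b) for b in [{"type": "custom_tool_call_output", "output": "27"}]
)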

@@ -575,12 +575,23 @@ def convert_to_openai_tool(
        Added support for OpenAI's image generation built-in tool.
    """
    from langchain_core.tools import Tool

    if isinstance(tool, dict):
        if tool.get("type") in _WellKnownOpenAITools:
            return tool
        # As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
        if (tool.get("type") or "").startswith("web_search_preview"):
            return tool
    if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
        oai_tool = {
            "type": "custom",
            "name": tool.name,
            "description": tool.description,
        }
        if tool.metadata is not None and "format" in tool.metadata:
            oai_tool["format"] = tool.metadata["format"]
        return oai_tool
    oai_function = convert_to_openai_function(tool, strict=strict)
    return {"type": "function", "function": oai_function}
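
As a quick check of what the new branch produces (this mirrors the unit test added later in this commit; a sketch, not part of the diff):

from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import custom_tool


@custom_tool
def execute_code(code: str) -> str:
    """Execute python code."""
    return "27"


# Custom tools are emitted as a top-level "custom" tool spec rather than the
# usual {"type": "function", "function": {...}} wrapper:
assert convert_to_openai_tool(execute_code) == {
    "type": "custom",
    "name": "execute_code",
    "description": "Execute python code.",
}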

@@ -1,6 +1,7 @@
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.tools import custom_tool
__all__ = [
"OpenAI",
@@ -9,4 +10,5 @@ __all__ = [
"AzureOpenAI",
"AzureChatOpenAI",
"AzureOpenAIEmbeddings",
"custom_tool",
]

@@ -3582,6 +3582,20 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
    return computer_call_output


def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict]:
    custom_tool_output = None
    for block in message.content:
        if isinstance(block, dict) and block.get("type") == "custom_tool_call_output":
            custom_tool_output = {
                "type": "custom_tool_call_output",
                "call_id": message.tool_call_id,
                "output": block.get("output") or "",
            }
            break
    return custom_tool_output


def _pop_index_and_sub_index(block: dict) -> dict:
    """When streaming, langchain-core uses the ``index`` key to aggregate
    text blocks. OpenAI API does not support this key, so we need to remove it.
@@ -3608,7 +3622,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                msg.pop("name")
            if msg["role"] == "tool":
                tool_output = msg["content"]
-               if lc_msg.additional_kwargs.get("type") == "computer_call_output":
+               custom_tool_output = _make_custom_tool_output_from_message(lc_msg)  # type: ignore[arg-type]
+               if custom_tool_output:
+                   input_.append(custom_tool_output)
+               elif lc_msg.additional_kwargs.get("type") == "computer_call_output":
                    computer_call_output = _make_computer_call_output_from_message(
                        cast(ToolMessage, lc_msg)
                    )
@@ -3663,6 +3680,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                    "file_search_call",
                    "function_call",
                    "computer_call",
                    "custom_tool_call",
                    "code_interpreter_call",
                    "mcp_call",
                    "mcp_list_tools",
@@ -3690,7 +3708,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
        content_call_ids = {
            block["call_id"]
            for block in input_
-           if block.get("type") == "function_call" and "call_id" in block
+           if block.get("type") in ("function_call", "custom_tool_call")
+           and "call_id" in block
        }
        for tool_call in tool_calls:
            if tool_call["id"] not in content_call_ids:
@@ -3841,6 +3860,15 @@ def _construct_lc_result_from_responses_api(
                    "error": error,
                }
                invalid_tool_calls.append(tool_call)
            elif output.type == "custom_tool_call":
                content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
                tool_call = {
                    "type": "tool_call",
                    "name": output.name,
                    "args": {"__arg1": output.input},
                    "id": output.call_id,
                }
                tool_calls.append(tool_call)
            elif output.type in (
                "reasoning",
                "web_search_call",
@@ -4044,6 +4072,23 @@ def _convert_responses_chunk_to_generation_chunk(
        tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
        tool_output["index"] = current_index
        content.append(tool_output)
    elif (
        chunk.type == "response.output_item.done"
        and chunk.item.type == "custom_tool_call"
    ):
        _advance(chunk.output_index)
        tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
        tool_output["index"] = current_index
        content.append(tool_output)
        tool_call_chunks.append(
            {
                "type": "tool_call_chunk",
                "name": chunk.item.name,
                "args": json.dumps({"__arg1": chunk.item.input}),
                "id": chunk.item.call_id,
                "index": current_index,
            }
        )
    elif chunk.type == "response.function_call_arguments.delta":
        _advance(chunk.output_index)
        tool_call_chunks.append(
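
Taken together, these hunks round-trip a custom tool call through LangChain's standard tool-calling types. A sketch of the mapping with illustrative values (not real API output):

# A Responses API output item of type "custom_tool_call"...
output_item = {
    "type": "custom_tool_call",
    "id": "ctc_123",  # illustrative item id
    "call_id": "call_123",  # ties the call to its eventual output
    "name": "execute_code",
    "input": "print(3**3)",  # freeform string argument
}

# ...surfaces as an ordinary LangChain tool call, with the raw string stored
# under the synthetic "__arg1" key (see the branch added above):
lc_tool_call = {
    "type": "tool_call",
    "name": output_item["name"],
    "args": {"__arg1": output_item["input"]},
    "id": output_item["call_id"],
}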

@@ -0,0 +1,3 @@
from langchain_openai.tools.custom_tool import custom_tool

__all__ = ["custom_tool"]

@@ -0,0 +1,109 @@
import inspect
from collections.abc import Awaitable
from typing import Any, Callable

from langchain_core.tools import tool


def _make_wrapped_func(func: Callable[..., str]) -> Callable[..., list[dict[str, Any]]]:
    def wrapped(x: str) -> list[dict[str, Any]]:
        # Wrap the plain string result in the Responses API content-block format.
        return [{"type": "custom_tool_call_output", "output": func(x)}]

    return wrapped


def _make_wrapped_coroutine(
    coroutine: Callable[..., Awaitable[str]],
) -> Callable[..., Awaitable[list[dict[str, Any]]]]:
    async def wrapped(*args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        result = await coroutine(*args, **kwargs)
        return [{"type": "custom_tool_call_output", "output": result}]

    return wrapped


def custom_tool(*args: Any, **kwargs: Any) -> Any:
    """Decorator to create an OpenAI custom tool.

    Custom tools accept (potentially long) freeform string inputs instead of a
    structured argument schema.

    See below for an example using LangGraph:

    .. code-block:: python

        from langchain_openai import ChatOpenAI, custom_tool
        from langgraph.prebuilt import create_react_agent


        @custom_tool
        def execute_code(code: str) -> str:
            \"\"\"Execute python code.\"\"\"
            return "27"


        llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")

        agent = create_react_agent(llm, [execute_code])

        input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
        for step in agent.stream(
            {"messages": [input_message]},
            stream_mode="values",
        ):
            step["messages"][-1].pretty_print()

    You can also specify a format for a corresponding context-free grammar using the
    ``format`` kwarg:

    .. code-block:: python

        from langchain_openai import ChatOpenAI, custom_tool
        from langgraph.prebuilt import create_react_agent

        grammar = \"\"\"
        start: expr
        expr: term (SP ADD SP term)* -> add
            | term
        term: factor (SP MUL SP factor)* -> mul
            | factor
        factor: INT
        SP: " "
        ADD: "+"
        MUL: "*"
        %import common.INT
        \"\"\"

        format = {"type": "grammar", "syntax": "lark", "definition": grammar}


        # highlight-next-line
        @custom_tool(format=format)
        def do_math(input_string: str) -> str:
            \"\"\"Do a mathematical operation.\"\"\"
            return "27"


        llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")

        agent = create_react_agent(llm, [do_math])

        input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
        for step in agent.stream(
            {"messages": [input_message]},
            stream_mode="values",
        ):
            step["messages"][-1].pretty_print()
    """

    def decorator(func: Callable[..., Any]) -> Any:
        metadata = {"type": "custom_tool"}
        if "format" in kwargs:
            metadata["format"] = kwargs.pop("format")
        tool_obj = tool(infer_schema=False, **kwargs)(func)
        tool_obj.metadata = metadata
        tool_obj.description = func.__doc__
        if inspect.iscoroutinefunction(func):
            tool_obj.coroutine = _make_wrapped_coroutine(func)
        else:
            tool_obj.func = _make_wrapped_func(func)
        return tool_obj

    if args and callable(args[0]) and not kwargs:
        # Bare usage: @custom_tool with no arguments.
        return decorator(args[0])
    return decorator
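
To see the output wrapping in isolation (a sketch; `reverse` is a made-up example tool):

from langchain_openai import custom_tool


@custom_tool
def reverse(text: str) -> str:
    """Reverse the input string."""
    return text[::-1]


# The decorator wraps the plain string return value in the content-block
# format the Responses API expects back from a custom tool:
tool_message = reverse.invoke(
    {"type": "tool_call", "name": "reverse", "args": {"__arg1": "abc"}, "id": "call_1"}
)
print(tool_message.content)
# -> [{'type': 'custom_tool_call_output', 'output': 'cba'}]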

@@ -17,7 +17,7 @@ from langchain_core.messages import (
from pydantic import BaseModel
from typing_extensions import TypedDict

-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, custom_tool

MODEL_NAME = "gpt-4o-mini"
@@ -672,3 +672,32 @@ def test_image_generation_multi_turn() -> None:
    _check_response(ai_message2)
    tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
    assert set(tool_output2.keys()).issubset(expected_keys)


@pytest.mark.vcr()
def test_custom_tool() -> None:
    @custom_tool
    def execute_code(code: str) -> str:
        """Execute python code."""
        return "27"

    llm = ChatOpenAI(model="gpt-5", output_version="responses/v1").bind_tools(
        [execute_code]
    )
    input_message = {"role": "user", "content": "Use the tool to evaluate 3^3."}
    tool_call_message = llm.invoke([input_message])
    assert isinstance(tool_call_message, AIMessage)
    assert len(tool_call_message.tool_calls) == 1

    tool_call = tool_call_message.tool_calls[0]
    tool_message = execute_code.invoke(tool_call)
    response = llm.invoke([input_message, tool_call_message, tool_message])
    assert isinstance(response, AIMessage)

    # Test streaming
    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    assert len(full.tool_calls) == 1

@@ -7,6 +7,7 @@ EXPECTED_ALL = [
"AzureOpenAI",
"AzureChatOpenAI",
"AzureOpenAIEmbeddings",
"custom_tool",
]

@@ -0,0 +1,120 @@
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import Tool

from langchain_openai import ChatOpenAI, custom_tool


def test_custom_tool() -> None:
    @custom_tool
    def my_tool(x: str) -> str:
        """Do thing."""
        return "a" + x

    # Test decorator
    assert isinstance(my_tool, Tool)
    assert my_tool.metadata == {"type": "custom_tool"}
    assert my_tool.description == "Do thing."

    result = my_tool.invoke(
        {
            "type": "tool_call",
            "name": "my_tool",
            "args": {"whatever": "b"},
            "id": "abc",
            "extras": {"type": "custom_tool_call"},
        }
    )
    assert result == ToolMessage(
        [{"type": "custom_tool_call_output", "output": "ab"}],
        name="my_tool",
        tool_call_id="abc",
    )

    # Test tool schema
    ## Test with format
    @custom_tool(format={"type": "grammar", "syntax": "lark", "definition": "..."})
    def another_tool(x: str) -> None:
        """Do thing."""
        pass

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True).bind_tools([another_tool])
    assert llm.kwargs == {  # type: ignore[attr-defined]
        "tools": [
            {
                "type": "custom",
                "name": "another_tool",
                "description": "Do thing.",
                "format": {"type": "grammar", "syntax": "lark", "definition": "..."},
            }
        ]
    }

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True).bind_tools([my_tool])
    assert llm.kwargs == {  # type: ignore[attr-defined]
        "tools": [{"type": "custom", "name": "my_tool", "description": "Do thing."}]
    }

    # Test passing messages back
    message_history = [
        HumanMessage("Use the tool"),
        AIMessage(
            [
                {
                    "type": "custom_tool_call",
                    "id": "ctc_abc123",
                    "call_id": "abc",
                    "name": "my_tool",
                    "input": "a",
                }
            ],
            tool_calls=[
                {
                    "type": "tool_call",
                    "name": "my_tool",
                    "args": {"__arg1": "a"},
                    "id": "abc",
                }
            ],
        ),
        result,
    ]
    payload = llm._get_request_payload(message_history)  # type: ignore[attr-defined]
    expected_input = [
        {"content": "Use the tool", "role": "user"},
        {
            "type": "custom_tool_call",
            "id": "ctc_abc123",
            "call_id": "abc",
            "name": "my_tool",
            "input": "a",
        },
        {"type": "custom_tool_call_output", "call_id": "abc", "output": "ab"},
    ]
    assert payload["input"] == expected_input


async def test_async_custom_tool() -> None:
    @custom_tool
    async def my_async_tool(x: str) -> str:
        """Do async thing."""
        return "a" + x

    # Test decorator
    assert isinstance(my_async_tool, Tool)
    assert my_async_tool.metadata == {"type": "custom_tool"}
    assert my_async_tool.description == "Do async thing."

    result = await my_async_tool.ainvoke(
        {
            "type": "tool_call",
            "name": "my_async_tool",
            "args": {"whatever": "b"},
            "id": "abc",
            "extras": {"type": "custom_tool_call"},
        }
    )
    assert result == ToolMessage(
        [{"type": "custom_tool_call_output", "output": "ab"}],
        name="my_async_tool",
        tool_call_id="abc",
    )