feat(openai): custom tools (#32449)

ccurme 2025-08-07 17:30:01 -03:00 committed by GitHub
parent 145d38f7dd
commit ec2b34a02d
11 changed files with 488 additions and 4 deletions

View File

@@ -447,6 +447,163 @@
")"
]
},
{
"cell_type": "markdown",
"id": "c5d9d19d-8ab1-4d9d-b3a0-56ee4e89c528",
"metadata": {},
"source": [
"### Custom tools\n",
"\n",
":::info Requires ``langchain-openai>=0.3.29``\n",
"\n",
":::\n",
"\n",
"[Custom tools](https://platform.openai.com/docs/guides/function-calling#custom-tools) support tools with arbitrary string inputs. They can be particularly useful when you expect your string arguments to be long or complex."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a47c809b-852f-46bd-8b9e-d9534c17213d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_6894ff5747c0819d9b02fc5645b0be9c000169fd9fb68d99', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_7SYwMSQPbbEqFcKlKOpXeEux', 'input': 'print(3**3)', 'name': 'execute_code', 'type': 'custom_tool_call', 'id': 'ctc_6894ff5b9f54819d8155a63638d34103000169fd9fb68d99', 'status': 'completed'}]\n",
"Tool Calls:\n",
" execute_code (call_7SYwMSQPbbEqFcKlKOpXeEux)\n",
" Call ID: call_7SYwMSQPbbEqFcKlKOpXeEux\n",
" Args:\n",
" __arg1: print(3**3)\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: execute_code\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6894ff5db3b8819d9159b3a370a25843000169fd9fb68d99'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"\n",
"@custom_tool\n",
"def execute_code(code: str) -> str:\n",
" \"\"\"Execute python code.\"\"\"\n",
" return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [execute_code])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
" {\"messages\": [input_message]},\n",
" stream_mode=\"values\",\n",
"):\n",
" step[\"messages\"][-1].pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "5ef93be6-6d4c-4eea-acfd-248774074082",
"metadata": {},
"source": [
"<details>\n",
"<summary>Context-free grammars</summary>\n",
"\n",
"OpenAI supports the specification of a [context-free grammar](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for custom tool inputs in `lark` or `regex` format. See [OpenAI docs](https://platform.openai.com/docs/guides/function-calling#context-free-grammars) for details. The `format` parameter can be passed into `@custom_tool` as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2ae04586-be33-49c6-8947-7867801d868f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Use the tool to calculate 3^3.\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'id': 'rs_689500828a8481a297ff0f98e328689c0681550c89797f43', 'summary': [], 'type': 'reasoning'}, {'call_id': 'call_jzH01RVhu6EFz7yUrOFXX55s', 'input': '3 * 3 * 3', 'name': 'do_math', 'type': 'custom_tool_call', 'id': 'ctc_6895008d57bc81a2b84d0993517a66b90681550c89797f43', 'status': 'completed'}]\n",
"Tool Calls:\n",
" do_math (call_jzH01RVhu6EFz7yUrOFXX55s)\n",
" Call ID: call_jzH01RVhu6EFz7yUrOFXX55s\n",
" Args:\n",
" __arg1: 3 * 3 * 3\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: do_math\n",
"\n",
"[{'type': 'custom_tool_call_output', 'output': '27'}]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"[{'type': 'text', 'text': '27', 'annotations': [], 'id': 'msg_6895009776b881a2a25f0be8507d08f20681550c89797f43'}]\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI, custom_tool\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"grammar = \"\"\"\n",
"start: expr\n",
"expr: term (SP ADD SP term)* -> add\n",
"| term\n",
"term: factor (SP MUL SP factor)* -> mul\n",
"| factor\n",
"factor: INT\n",
"SP: \" \"\n",
"ADD: \"+\"\n",
"MUL: \"*\"\n",
"%import common.INT\n",
"\"\"\"\n",
"\n",
"format_ = {\"type\": \"grammar\", \"syntax\": \"lark\", \"definition\": grammar}\n",
"\n",
"\n",
"# highlight-next-line\n",
"@custom_tool(format=format_)\n",
"def do_math(input_string: str) -> str:\n",
" \"\"\"Do a mathematical operation.\"\"\"\n",
" return \"27\"\n",
"\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-5\", output_version=\"responses/v1\")\n",
"\n",
"agent = create_react_agent(llm, [do_math])\n",
"\n",
"input_message = {\"role\": \"user\", \"content\": \"Use the tool to calculate 3^3.\"}\n",
"for step in agent.stream(\n",
" {\"messages\": [input_message]},\n",
" stream_mode=\"values\",\n",
"):\n",
" step[\"messages\"][-1].pretty_print()"
]
},
{
"cell_type": "markdown",
"id": "c63430c9-c7b0-4e92-a491-3f165dddeb8f",
"metadata": {},
"source": [
"</details>"
]
},
{
"cell_type": "markdown",
"id": "84833dd0-17e9-4269-82ed-550639d65751",

View File

@@ -74,7 +74,14 @@ if TYPE_CHECKING:
    from collections.abc import Sequence

FILTERED_ARGS = ("run_manager", "callbacks")
TOOL_MESSAGE_BLOCK_TYPES = (
    "text",
    "image_url",
    "image",
    "json",
    "search_result",
    "custom_tool_call_output",
)


class SchemaAnnotationError(TypeError):
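For reference, the block type added here is exactly what the wrapped custom tool returns, so its output can round-trip through a ToolMessage. A minimal sketch of the shape (the tool name and call id below are illustrative, not taken from the diff):

from langchain_core.messages import ToolMessage

# "custom_tool_call_output" is now an accepted ToolMessage content block type.
msg = ToolMessage(
    [{"type": "custom_tool_call_output", "output": "27"}],
    name="execute_code",  # illustrative tool name
    tool_call_id="call_abc123",  # illustrative call id
)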

View File

@@ -575,12 +575,23 @@ def convert_to_openai_tool(
        Added support for OpenAI's image generation built-in tool.
    """
    from langchain_core.tools import Tool

    if isinstance(tool, dict):
        if tool.get("type") in _WellKnownOpenAITools:
            return tool
        # As of 03.12.25 can be "web_search_preview" or "web_search_preview_2025_03_11"
        if (tool.get("type") or "").startswith("web_search_preview"):
            return tool
    if isinstance(tool, Tool) and (tool.metadata or {}).get("type") == "custom_tool":
        oai_tool = {
            "type": "custom",
            "name": tool.name,
            "description": tool.description,
        }
        if tool.metadata is not None and "format" in tool.metadata:
            oai_tool["format"] = tool.metadata["format"]
        return oai_tool
    oai_function = convert_to_openai_function(tool, strict=strict)
    return {"type": "function", "function": oai_function}

View File

@@ -1,6 +1,7 @@
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.tools import custom_tool

__all__ = [
    "OpenAI",
@@ -9,4 +10,5 @@ __all__ = [
    "AzureOpenAI",
    "AzureChatOpenAI",
    "AzureOpenAIEmbeddings",
    "custom_tool",
]

View File

@@ -3582,6 +3582,20 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
    return computer_call_output


def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict]:
    custom_tool_output = None
    for block in message.content:
        if isinstance(block, dict) and block.get("type") == "custom_tool_call_output":
            custom_tool_output = {
                "type": "custom_tool_call_output",
                "call_id": message.tool_call_id,
                "output": block.get("output") or "",
            }
            break
    return custom_tool_output


def _pop_index_and_sub_index(block: dict) -> dict:
    """When streaming, langchain-core uses the ``index`` key to aggregate
    text blocks. OpenAI API does not support this key, so we need to remove it.
@@ -3608,7 +3622,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
            msg.pop("name")
        if msg["role"] == "tool":
            tool_output = msg["content"]
            custom_tool_output = _make_custom_tool_output_from_message(lc_msg)  # type: ignore[arg-type]
            if custom_tool_output:
                input_.append(custom_tool_output)
            elif lc_msg.additional_kwargs.get("type") == "computer_call_output":
                computer_call_output = _make_computer_call_output_from_message(
                    cast(ToolMessage, lc_msg)
                )
@@ -3663,6 +3680,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                "file_search_call",
                "function_call",
                "computer_call",
                "custom_tool_call",
                "code_interpreter_call",
                "mcp_call",
                "mcp_list_tools",
@@ -3690,7 +3708,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
            content_call_ids = {
                block["call_id"]
                for block in input_
                if block.get("type") in ("function_call", "custom_tool_call")
                and "call_id" in block
            }
            for tool_call in tool_calls:
                if tool_call["id"] not in content_call_ids:
@@ -3841,6 +3860,15 @@ def _construct_lc_result_from_responses_api(
                "error": error,
            }
            invalid_tool_calls.append(tool_call)
        elif output.type == "custom_tool_call":
            content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
            tool_call = {
                "type": "tool_call",
                "name": output.name,
                "args": {"__arg1": output.input},
                "id": output.call_id,
            }
            tool_calls.append(tool_call)
        elif output.type in (
            "reasoning",
            "web_search_call",
@@ -4044,6 +4072,23 @@ def _convert_responses_chunk_to_generation_chunk(
        tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
        tool_output["index"] = current_index
        content.append(tool_output)
    elif (
        chunk.type == "response.output_item.done"
        and chunk.item.type == "custom_tool_call"
    ):
        _advance(chunk.output_index)
        tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
        tool_output["index"] = current_index
        content.append(tool_output)
        tool_call_chunks.append(
            {
                "type": "tool_call_chunk",
                "name": chunk.item.name,
                "args": json.dumps({"__arg1": chunk.item.input}),
                "id": chunk.item.call_id,
                "index": current_index,
            }
        )
    elif chunk.type == "response.function_call_arguments.delta":
        _advance(chunk.output_index)
        tool_call_chunks.append(

View File

@@ -0,0 +1,3 @@
from langchain_openai.tools.custom_tool import custom_tool

__all__ = ["custom_tool"]

View File

@@ -0,0 +1,109 @@
import inspect
from collections.abc import Awaitable
from typing import Any, Callable

from langchain_core.tools import tool


def _make_wrapped_func(func: Callable[..., str]) -> Callable[..., list[dict[str, Any]]]:
    def wrapped(x: str) -> list[dict[str, Any]]:
        return [{"type": "custom_tool_call_output", "output": func(x)}]

    return wrapped


def _make_wrapped_coroutine(
    coroutine: Callable[..., Awaitable[str]],
) -> Callable[..., Awaitable[list[dict[str, Any]]]]:
    async def wrapped(*args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        result = await coroutine(*args, **kwargs)
        return [{"type": "custom_tool_call_output", "output": result}]

    return wrapped


def custom_tool(*args: Any, **kwargs: Any) -> Any:
    """Decorator to create an OpenAI custom tool.

    Custom tools allow for tools with (potentially long) freeform string inputs.

    See below for an example using LangGraph:

    .. code-block:: python

        from langchain_openai import ChatOpenAI, custom_tool
        from langgraph.prebuilt import create_react_agent


        @custom_tool
        def execute_code(code: str) -> str:
            \"\"\"Execute python code.\"\"\"
            return "27"


        llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")

        agent = create_react_agent(llm, [execute_code])

        input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
        for step in agent.stream(
            {"messages": [input_message]},
            stream_mode="values",
        ):
            step["messages"][-1].pretty_print()

    You can also specify a format for a corresponding context-free grammar using the
    ``format`` kwarg:

    .. code-block:: python

        from langchain_openai import ChatOpenAI, custom_tool
        from langgraph.prebuilt import create_react_agent

        grammar = \"\"\"
        start: expr
        expr: term (SP ADD SP term)* -> add
        | term
        term: factor (SP MUL SP factor)* -> mul
        | factor
        factor: INT
        SP: " "
        ADD: "+"
        MUL: "*"
        %import common.INT
        \"\"\"

        format_ = {"type": "grammar", "syntax": "lark", "definition": grammar}


        # highlight-next-line
        @custom_tool(format=format_)
        def do_math(input_string: str) -> str:
            \"\"\"Do a mathematical operation.\"\"\"
            return "27"


        llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")

        agent = create_react_agent(llm, [do_math])

        input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
        for step in agent.stream(
            {"messages": [input_message]},
            stream_mode="values",
        ):
            step["messages"][-1].pretty_print()
    """

    def decorator(func: Callable[..., Any]) -> Any:
        metadata = {"type": "custom_tool"}
        if "format" in kwargs:
            metadata["format"] = kwargs.pop("format")

        tool_obj = tool(infer_schema=False, **kwargs)(func)
        tool_obj.metadata = metadata
        tool_obj.description = func.__doc__
        if inspect.iscoroutinefunction(func):
            tool_obj.coroutine = _make_wrapped_coroutine(func)
        else:
            tool_obj.func = _make_wrapped_func(func)
        return tool_obj

    if args and callable(args[0]) and not kwargs:
        return decorator(args[0])
    return decorator
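Since the decorator also wraps coroutines (via `_make_wrapped_coroutine`), async functions work the same way. A small sketch mirroring the async unit test at the end of this diff; the tool name, query, and call id are illustrative:

import asyncio

from langchain_openai import custom_tool


@custom_tool
async def search_notes(query: str) -> str:
    """Search notes for a freeform query."""  # illustrative async custom tool
    return f"results for: {query}"


async def main() -> None:
    # Invoking with a tool call yields a ToolMessage whose content is a
    # [{"type": "custom_tool_call_output", "output": ...}] block, per the wrapper above.
    result = await search_notes.ainvoke(
        {
            "type": "tool_call",
            "name": "search_notes",
            "args": {"__arg1": "project deadlines"},
            "id": "call_123",  # illustrative call id
        }
    )
    print(result.content)


asyncio.run(main())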

View File

@@ -17,7 +17,7 @@ from langchain_core.messages import (
from pydantic import BaseModel
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI, custom_tool

MODEL_NAME = "gpt-4o-mini"
@@ -672,3 +672,32 @@ def test_image_generation_multi_turn() -> None:
    _check_response(ai_message2)
    tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
    assert set(tool_output2.keys()).issubset(expected_keys)


@pytest.mark.vcr()
def test_custom_tool() -> None:
    @custom_tool
    def execute_code(code: str) -> str:
        """Execute python code."""
        return "27"

    llm = ChatOpenAI(model="gpt-5", output_version="responses/v1").bind_tools(
        [execute_code]
    )
    input_message = {"role": "user", "content": "Use the tool to evaluate 3^3."}
    tool_call_message = llm.invoke([input_message])
    assert isinstance(tool_call_message, AIMessage)
    assert len(tool_call_message.tool_calls) == 1

    tool_call = tool_call_message.tool_calls[0]
    tool_message = execute_code.invoke(tool_call)
    response = llm.invoke([input_message, tool_call_message, tool_message])
    assert isinstance(response, AIMessage)

    # Test streaming
    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    assert len(full.tool_calls) == 1

View File

@@ -7,6 +7,7 @@ EXPECTED_ALL = [
    "AzureOpenAI",
    "AzureChatOpenAI",
    "AzureOpenAIEmbeddings",
    "custom_tool",
]

View File

@@ -0,0 +1,120 @@
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import Tool

from langchain_openai import ChatOpenAI, custom_tool


def test_custom_tool() -> None:
    @custom_tool
    def my_tool(x: str) -> str:
        """Do thing."""
        return "a" + x

    # Test decorator
    assert isinstance(my_tool, Tool)
    assert my_tool.metadata == {"type": "custom_tool"}
    assert my_tool.description == "Do thing."

    result = my_tool.invoke(
        {
            "type": "tool_call",
            "name": "my_tool",
            "args": {"whatever": "b"},
            "id": "abc",
            "extras": {"type": "custom_tool_call"},
        }
    )
    assert result == ToolMessage(
        [{"type": "custom_tool_call_output", "output": "ab"}],
        name="my_tool",
        tool_call_id="abc",
    )

    # Test tool schema
    ## Test with format
    @custom_tool(format={"type": "grammar", "syntax": "lark", "definition": "..."})
    def another_tool(x: str) -> None:
        """Do thing."""
        pass

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True).bind_tools([another_tool])
    assert llm.kwargs == {  # type: ignore[attr-defined]
        "tools": [
            {
                "type": "custom",
                "name": "another_tool",
                "description": "Do thing.",
                "format": {"type": "grammar", "syntax": "lark", "definition": "..."},
            }
        ]
    }

    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True).bind_tools([my_tool])
    assert llm.kwargs == {  # type: ignore[attr-defined]
        "tools": [{"type": "custom", "name": "my_tool", "description": "Do thing."}]
    }

    # Test passing messages back
    message_history = [
        HumanMessage("Use the tool"),
        AIMessage(
            [
                {
                    "type": "custom_tool_call",
                    "id": "ctc_abc123",
                    "call_id": "abc",
                    "name": "my_tool",
                    "input": "a",
                }
            ],
            tool_calls=[
                {
                    "type": "tool_call",
                    "name": "my_tool",
                    "args": {"__arg1": "a"},
                    "id": "abc",
                }
            ],
        ),
        result,
    ]
    payload = llm._get_request_payload(message_history)  # type: ignore[attr-defined]
    expected_input = [
        {"content": "Use the tool", "role": "user"},
        {
            "type": "custom_tool_call",
            "id": "ctc_abc123",
            "call_id": "abc",
            "name": "my_tool",
            "input": "a",
        },
        {"type": "custom_tool_call_output", "call_id": "abc", "output": "ab"},
    ]
    assert payload["input"] == expected_input


async def test_async_custom_tool() -> None:
    @custom_tool
    async def my_async_tool(x: str) -> str:
        """Do async thing."""
        return "a" + x

    # Test decorator
    assert isinstance(my_async_tool, Tool)
    assert my_async_tool.metadata == {"type": "custom_tool"}
    assert my_async_tool.description == "Do async thing."

    result = await my_async_tool.ainvoke(
        {
            "type": "tool_call",
            "name": "my_async_tool",
            "args": {"whatever": "b"},
            "id": "abc",
            "extras": {"type": "custom_tool_call"},
        }
    )
    assert result == ToolMessage(
        [{"type": "custom_tool_call_output", "output": "ab"}],
        name="my_async_tool",
        tool_call_id="abc",
    )