feat(openai): custom tools (#32449)
libs/partners/openai/langchain_openai/__init__.py:

@@ -1,6 +1,7 @@
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from langchain_openai.llms import AzureOpenAI, OpenAI
+from langchain_openai.tools import custom_tool

 __all__ = [
     "OpenAI",
@@ -9,4 +10,5 @@ __all__ = [
     "AzureOpenAI",
     "AzureChatOpenAI",
     "AzureOpenAIEmbeddings",
+    "custom_tool",
 ]
libs/partners/openai/langchain_openai/chat_models/base.py:

@@ -3582,6 +3582,20 @@ def _make_computer_call_output_from_message(message: ToolMessage) -> dict:
     return computer_call_output


+def _make_custom_tool_output_from_message(message: ToolMessage) -> Optional[dict]:
+    custom_tool_output = None
+    for block in message.content:
+        if isinstance(block, dict) and block.get("type") == "custom_tool_call_output":
+            custom_tool_output = {
+                "type": "custom_tool_call_output",
+                "call_id": message.tool_call_id,
+                "output": block.get("output") or "",
+            }
+            break
+
+    return custom_tool_output
+
+
 def _pop_index_and_sub_index(block: dict) -> dict:
     """When streaming, langchain-core uses the ``index`` key to aggregate
     text blocks. OpenAI API does not support this key, so we need to remove it.
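For orientation, here is a hedged sketch (not part of the diff; the message values are invented) of the contract of the helper added above: given a ToolMessage whose content carries a custom_tool_call_output block, it builds the corresponding Responses API payload item, reusing the message's tool_call_id.

    from langchain_core.messages import ToolMessage

    # Hypothetical tool result produced by a custom tool (values invented):
    msg = ToolMessage(
        [{"type": "custom_tool_call_output", "output": "27"}],
        tool_call_id="abc",
        name="execute_code",
    )
    # _make_custom_tool_output_from_message(msg) should then return:
    # {"type": "custom_tool_call_output", "call_id": "abc", "output": "27"}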
@@ -3608,7 +3622,10 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                 msg.pop("name")
             if msg["role"] == "tool":
                 tool_output = msg["content"]
-                if lc_msg.additional_kwargs.get("type") == "computer_call_output":
+                custom_tool_output = _make_custom_tool_output_from_message(lc_msg)  # type: ignore[arg-type]
+                if custom_tool_output:
+                    input_.append(custom_tool_output)
+                elif lc_msg.additional_kwargs.get("type") == "computer_call_output":
                     computer_call_output = _make_computer_call_output_from_message(
                         cast(ToolMessage, lc_msg)
                     )
@@ -3663,6 +3680,7 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
                 "file_search_call",
                 "function_call",
                 "computer_call",
+                "custom_tool_call",
                 "code_interpreter_call",
                 "mcp_call",
                 "mcp_list_tools",
@@ -3690,7 +3708,8 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
             content_call_ids = {
                 block["call_id"]
                 for block in input_
-                if block.get("type") == "function_call" and "call_id" in block
+                if block.get("type") in ("function_call", "custom_tool_call")
+                and "call_id" in block
             }
             for tool_call in tool_calls:
                 if tool_call["id"] not in content_call_ids:
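To illustrate the widened dedup guard (sample data invented): call IDs already present as custom_tool_call content blocks are now collected too, so the loop above will not re-append a function_call item for them.

    # Hypothetical mid-conversion state of input_:
    input_ = [
        {"type": "custom_tool_call", "call_id": "abc", "name": "my_tool", "input": "a"},
    ]
    content_call_ids = {
        block["call_id"]
        for block in input_
        if block.get("type") in ("function_call", "custom_tool_call") and "call_id" in block
    }
    assert content_call_ids == {"abc"}  # call "abc" is skipped, not duplicated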
@@ -3841,6 +3860,15 @@ def _construct_lc_result_from_responses_api(
                 "error": error,
             }
             invalid_tool_calls.append(tool_call)
+        elif output.type == "custom_tool_call":
+            content_blocks.append(output.model_dump(exclude_none=True, mode="json"))
+            tool_call = {
+                "type": "tool_call",
+                "name": output.name,
+                "args": {"__arg1": output.input},
+                "id": output.call_id,
+            }
+            tool_calls.append(tool_call)
         elif output.type in (
             "reasoning",
             "web_search_call",
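Since a custom tool receives one freeform string rather than a JSON object of arguments, the branch above wraps the output item's input under the synthetic __arg1 key. A minimal sketch with invented values:

    # A Responses API output item of type "custom_tool_call" (invented values) ...
    output_item = {
        "type": "custom_tool_call",
        "name": "execute_code",
        "input": "print(3**3)",
        "call_id": "abc",
    }
    # ... maps to a LangChain tool call whose single string argument sits under "__arg1":
    tool_call = {
        "type": "tool_call",
        "name": output_item["name"],
        "args": {"__arg1": output_item["input"]},
        "id": output_item["call_id"],
    }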
@@ -4044,6 +4072,23 @@ def _convert_responses_chunk_to_generation_chunk(
         tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
         tool_output["index"] = current_index
         content.append(tool_output)
+    elif (
+        chunk.type == "response.output_item.done"
+        and chunk.item.type == "custom_tool_call"
+    ):
+        _advance(chunk.output_index)
+        tool_output = chunk.item.model_dump(exclude_none=True, mode="json")
+        tool_output["index"] = current_index
+        content.append(tool_output)
+        tool_call_chunks.append(
+            {
+                "type": "tool_call_chunk",
+                "name": chunk.item.name,
+                "args": json.dumps({"__arg1": chunk.item.input}),
+                "id": chunk.item.call_id,
+                "index": current_index,
+            }
+        )
     elif chunk.type == "response.function_call_arguments.delta":
         _advance(chunk.output_index)
         tool_call_chunks.append(
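Note that in this streaming path the whole custom tool call arrives on a single response.output_item.done event, so one already-complete tool_call_chunk is emitted rather than incremental argument deltas. A hedged sketch of the resulting chunk (values invented):

    import json

    # A finished "custom_tool_call" item from the stream (invented values) ...
    item = {"name": "execute_code", "input": "print(3**3)", "call_id": "abc"}
    # ... yields a single, complete tool call chunk:
    chunk = {
        "type": "tool_call_chunk",
        "name": item["name"],
        "args": json.dumps({"__arg1": item["input"]}),
        "id": item["call_id"],
        "index": 0,  # whatever current_index is at this point
    }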
libs/partners/openai/langchain_openai/tools/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from langchain_openai.tools.custom_tool import custom_tool
+
+__all__ = ["custom_tool"]
libs/partners/openai/langchain_openai/tools/custom_tool.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+import inspect
+from collections.abc import Awaitable
+from typing import Any, Callable
+
+from langchain_core.tools import tool
+
+
+def _make_wrapped_func(func: Callable[..., str]) -> Callable[..., list[dict[str, Any]]]:
+    def wrapped(x: str) -> list[dict[str, Any]]:
+        return [{"type": "custom_tool_call_output", "output": func(x)}]
+
+    return wrapped
+
+
+def _make_wrapped_coroutine(
+    coroutine: Callable[..., Awaitable[str]],
+) -> Callable[..., Awaitable[list[dict[str, Any]]]]:
+    async def wrapped(*args: Any, **kwargs: Any) -> list[dict[str, Any]]:
+        result = await coroutine(*args, **kwargs)
+        return [{"type": "custom_tool_call_output", "output": result}]
+
+    return wrapped
+
+
+def custom_tool(*args: Any, **kwargs: Any) -> Any:
+    """Decorator to create an OpenAI custom tool.
+
+    Custom tools allow for tools with (potentially long) freeform string inputs.
+
+    See below for an example using LangGraph:
+
+    .. code-block:: python
+
+        @custom_tool
+        def execute_code(code: str) -> str:
+            \"\"\"Execute python code.\"\"\"
+            return "27"
+
+
+        llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")
+
+        agent = create_react_agent(llm, [execute_code])
+
+        input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
+        for step in agent.stream(
+            {"messages": [input_message]},
+            stream_mode="values",
+        ):
+            step["messages"][-1].pretty_print()
+
+    You can also specify a format for a corresponding context-free grammar using the
+    ``format`` kwarg:
+
+    .. code-block:: python
+
+        from langchain_openai import ChatOpenAI, custom_tool
+        from langgraph.prebuilt import create_react_agent
+
+        grammar = \"\"\"
+        start: expr
+        expr: term (SP ADD SP term)* -> add
+            | term
+        term: factor (SP MUL SP factor)* -> mul
+            | factor
+        factor: INT
+        SP: " "
+        ADD: "+"
+        MUL: "*"
+        %import common.INT
+        \"\"\"
+
+        format = {"type": "grammar", "syntax": "lark", "definition": grammar}
+
+        # highlight-next-line
+        @custom_tool(format=format)
+        def do_math(input_string: str) -> str:
+            \"\"\"Do a mathematical operation.\"\"\"
+            return "27"
+
+
+        llm = ChatOpenAI(model="gpt-5", output_version="responses/v1")
+
+        agent = create_react_agent(llm, [do_math])
+
+        input_message = {"role": "user", "content": "Use the tool to calculate 3^3."}
+        for step in agent.stream(
+            {"messages": [input_message]},
+            stream_mode="values",
+        ):
+            step["messages"][-1].pretty_print()
+    """
+
+    def decorator(func: Callable[..., Any]) -> Any:
+        metadata = {"type": "custom_tool"}
+        if "format" in kwargs:
+            metadata["format"] = kwargs.pop("format")
+        tool_obj = tool(infer_schema=False, **kwargs)(func)
+        tool_obj.metadata = metadata
+        tool_obj.description = func.__doc__
+        if inspect.iscoroutinefunction(func):
+            tool_obj.coroutine = _make_wrapped_coroutine(func)
+        else:
+            tool_obj.func = _make_wrapped_func(func)
+        return tool_obj
+
+    if args and callable(args[0]) and not kwargs:
+        return decorator(args[0])

+    return decorator
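Putting the new module together, a minimal sketch of invoking a decorated tool directly (it mirrors the unit tests below; the tool name and input are invented, and no API call is made):

    from langchain_openai import custom_tool

    @custom_tool
    def shout(text: str) -> str:
        """Uppercase the input."""
        return text.upper()

    # Invoking with a tool call routes the single string arg to the function:
    result = shout.invoke(
        {"type": "tool_call", "name": "shout", "args": {"__arg1": "hi"}, "id": "call_1"}
    )
    # result is a ToolMessage whose content is
    # [{"type": "custom_tool_call_output", "output": "HI"}]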
libs/partners/openai/tests/cassettes/test_custom_tool.yaml.gz (new binary file, not shown)
@@ -17,7 +17,7 @@ from langchain_core.messages import (
 from pydantic import BaseModel
 from typing_extensions import TypedDict

-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, custom_tool

 MODEL_NAME = "gpt-4o-mini"

@@ -672,3 +672,32 @@ def test_image_generation_multi_turn() -> None:
     _check_response(ai_message2)
     tool_output2 = ai_message2.additional_kwargs["tool_outputs"][0]
     assert set(tool_output2.keys()).issubset(expected_keys)
+
+
+@pytest.mark.vcr()
+def test_custom_tool() -> None:
+    @custom_tool
+    def execute_code(code: str) -> str:
+        """Execute python code."""
+        return "27"
+
+    llm = ChatOpenAI(model="gpt-5", output_version="responses/v1").bind_tools(
+        [execute_code]
+    )
+
+    input_message = {"role": "user", "content": "Use the tool to evaluate 3^3."}
+    tool_call_message = llm.invoke([input_message])
+    assert isinstance(tool_call_message, AIMessage)
+    assert len(tool_call_message.tool_calls) == 1
+    tool_call = tool_call_message.tool_calls[0]
+    tool_message = execute_code.invoke(tool_call)
+    response = llm.invoke([input_message, tool_call_message, tool_message])
+    assert isinstance(response, AIMessage)
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream([input_message]):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert len(full.tool_calls) == 1
@@ -7,6 +7,7 @@ EXPECTED_ALL = [
     "AzureOpenAI",
     "AzureChatOpenAI",
     "AzureOpenAIEmbeddings",
+    "custom_tool",
 ]


libs/partners/openai/tests/unit_tests/test_tools.py (new file, 120 lines)

@@ -0,0 +1,120 @@
+from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
+from langchain_core.tools import Tool
+
+from langchain_openai import ChatOpenAI, custom_tool
+
+
+def test_custom_tool() -> None:
+    @custom_tool
+    def my_tool(x: str) -> str:
+        """Do thing."""
+        return "a" + x
+
+    # Test decorator
+    assert isinstance(my_tool, Tool)
+    assert my_tool.metadata == {"type": "custom_tool"}
+    assert my_tool.description == "Do thing."
+
+    result = my_tool.invoke(
+        {
+            "type": "tool_call",
+            "name": "my_tool",
+            "args": {"whatever": "b"},
+            "id": "abc",
+            "extras": {"type": "custom_tool_call"},
+        }
+    )
+    assert result == ToolMessage(
+        [{"type": "custom_tool_call_output", "output": "ab"}],
+        name="my_tool",
+        tool_call_id="abc",
+    )
+
+    # Test tool schema
+    ## Test with format
+    @custom_tool(format={"type": "grammar", "syntax": "lark", "definition": "..."})
+    def another_tool(x: str) -> None:
+        """Do thing."""
+        pass
+
+    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True).bind_tools([another_tool])
+    assert llm.kwargs == {  # type: ignore[attr-defined]
+        "tools": [
+            {
+                "type": "custom",
+                "name": "another_tool",
+                "description": "Do thing.",
+                "format": {"type": "grammar", "syntax": "lark", "definition": "..."},
+            }
+        ]
+    }
+
+    llm = ChatOpenAI(model="gpt-4.1", use_responses_api=True).bind_tools([my_tool])
+    assert llm.kwargs == {  # type: ignore[attr-defined]
+        "tools": [{"type": "custom", "name": "my_tool", "description": "Do thing."}]
+    }
+
+    # Test passing messages back
+    message_history = [
+        HumanMessage("Use the tool"),
+        AIMessage(
+            [
+                {
+                    "type": "custom_tool_call",
+                    "id": "ctc_abc123",
+                    "call_id": "abc",
+                    "name": "my_tool",
+                    "input": "a",
+                }
+            ],
+            tool_calls=[
+                {
+                    "type": "tool_call",
+                    "name": "my_tool",
+                    "args": {"__arg1": "a"},
+                    "id": "abc",
+                }
+            ],
+        ),
+        result,
+    ]
+    payload = llm._get_request_payload(message_history)  # type: ignore[attr-defined]
+    expected_input = [
+        {"content": "Use the tool", "role": "user"},
+        {
+            "type": "custom_tool_call",
+            "id": "ctc_abc123",
+            "call_id": "abc",
+            "name": "my_tool",
+            "input": "a",
+        },
+        {"type": "custom_tool_call_output", "call_id": "abc", "output": "ab"},
+    ]
+    assert payload["input"] == expected_input
+
+
+async def test_async_custom_tool() -> None:
+    @custom_tool
+    async def my_async_tool(x: str) -> str:
+        """Do async thing."""
+        return "a" + x
+
+    # Test decorator
+    assert isinstance(my_async_tool, Tool)
+    assert my_async_tool.metadata == {"type": "custom_tool"}
+    assert my_async_tool.description == "Do async thing."
+
+    result = await my_async_tool.ainvoke(
+        {
+            "type": "tool_call",
+            "name": "my_async_tool",
+            "args": {"whatever": "b"},
+            "id": "abc",
+            "extras": {"type": "custom_tool_call"},
+        }
+    )
+    assert result == ToolMessage(
+        [{"type": "custom_tool_call_output", "output": "ab"}],
+        name="my_async_tool",
+        tool_call_id="abc",
+    )