# Mirror of https://github.com/hwchase17/langchain.git
# (synced 2025-08-16 08:06:14 +00:00; 125 lines, 3.6 KiB, Python)
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
from langchain_core.tools import Tool
|
|
|
|
from langchain_openai import ChatOpenAI, custom_tool
|
|
|
|
|
|
def test_custom_tool() -> None:
    """Exercise a synchronous ``custom_tool`` end to end.

    Checks:
    - the attributes of the ``Tool`` produced by the decorator;
    - the ``ToolMessage`` returned when it is invoked with a
      ``custom_tool_call``;
    - the tool schema emitted by ``ChatOpenAI.bind_tools``, both with and
      without a grammar ``format``;
    - how a custom tool call and its output are serialized into the Responses
      API request payload.
    """

    @custom_tool
    def my_tool(x: str) -> str:
        """Do thing."""
        return "a" + x

    # Test decorator
    assert isinstance(my_tool, Tool)
    assert my_tool.metadata == {"type": "custom_tool"}
    assert my_tool.description == "Do thing."

    # The arg key ("whatever") differs from the parameter name ``x``; the
    # expected "ab" output below shows the value is still forwarded.
    result = my_tool.invoke(
        {
            "type": "tool_call",
            "name": "my_tool",
            "args": {"whatever": "b"},
            "id": "abc",
            "extras": {"type": "custom_tool_call"},
        }
    )
    assert result == ToolMessage(
        [{"type": "custom_tool_call_output", "output": "ab"}],
        name="my_tool",
        tool_call_id="abc",
    )

    # Test tool schema
    # With a grammar format: the format dict appears verbatim in the schema.
    @custom_tool(format={"type": "grammar", "syntax": "lark", "definition": "..."})
    def another_tool(x: str) -> None:
        """Do thing."""

    llm = ChatOpenAI(
        model="gpt-4.1", use_responses_api=True, output_version="responses/v1"
    ).bind_tools([another_tool])
    assert llm.kwargs == {  # type: ignore[attr-defined]
        "tools": [
            {
                "type": "custom",
                "name": "another_tool",
                "description": "Do thing.",
                "format": {"type": "grammar", "syntax": "lark", "definition": "..."},
            }
        ]
    }

    # Without a format: the schema carries no "format" key.
    llm = ChatOpenAI(
        model="gpt-4.1", use_responses_api=True, output_version="responses/v1"
    ).bind_tools([my_tool])
    assert llm.kwargs == {  # type: ignore[attr-defined]
        "tools": [{"type": "custom", "name": "my_tool", "description": "Do thing."}]
    }

    # Test passing messages back: the ToolMessage produced above (``result``)
    # is replayed as the output for call_id "abc".
    message_history = [
        HumanMessage("Use the tool"),
        AIMessage(
            [
                {
                    "type": "custom_tool_call",
                    "id": "ctc_abc123",
                    "call_id": "abc",
                    "name": "my_tool",
                    "input": "a",
                }
            ],
            tool_calls=[
                {
                    "type": "tool_call",
                    "name": "my_tool",
                    "args": {"__arg1": "a"},
                    "id": "abc",
                }
            ],
        ),
        result,
    ]
    payload = llm._get_request_payload(message_history)  # type: ignore[attr-defined]
    expected_input = [
        {"content": "Use the tool", "role": "user"},
        {
            "type": "custom_tool_call",
            "id": "ctc_abc123",
            "call_id": "abc",
            "name": "my_tool",
            "input": "a",
        },
        # The ToolMessage's tool_call_id becomes the output item's call_id.
        {"type": "custom_tool_call_output", "call_id": "abc", "output": "ab"},
    ]
    assert payload["input"] == expected_input


async def test_async_custom_tool() -> None:
    """Exercise ``custom_tool`` applied to a coroutine function.

    Checks the attributes of the decorated ``Tool`` and the ``ToolMessage``
    returned by ``ainvoke`` for a ``custom_tool_call``.
    """

    @custom_tool
    async def my_async_tool(x: str) -> str:
        """Do async thing."""
        return "a" + x

    # Test decorator
    assert isinstance(my_async_tool, Tool)
    assert my_async_tool.metadata == {"type": "custom_tool"}
    assert my_async_tool.description == "Do async thing."

    # As in the sync test, the arg key ("whatever") differs from ``x``; the
    # expected "ab" output shows the value is still forwarded.
    result = await my_async_tool.ainvoke(
        {
            "type": "tool_call",
            "name": "my_async_tool",
            "args": {"whatever": "b"},
            "id": "abc",
            "extras": {"type": "custom_tool_call"},
        }
    )
    assert result == ToolMessage(
        [{"type": "custom_tool_call_output", "output": "ab"}],
        name="my_async_tool",
        tool_call_id="abc",
    )