mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 00:48:45 +00:00
Harrison/tool decorator (#790)
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com> Co-authored-by: Jason Liu <jason@jxnl.coA>
This commit is contained in:
parent
5f73d06502
commit
1ad7973cc6
@ -10,15 +10,17 @@
|
|||||||
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. A Tool is defined as below.\n",
|
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. A Tool is defined as below.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"```python\n",
|
"```python\n",
|
||||||
"class Tool(NamedTuple):\n",
|
"@dataclass \n",
|
||||||
|
"class Tool:\n",
|
||||||
" \"\"\"Interface for tools.\"\"\"\n",
|
" \"\"\"Interface for tools.\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" name: str\n",
|
" name: str\n",
|
||||||
" func: Callable[[str], str]\n",
|
" func: Callable[[str], str]\n",
|
||||||
" description: Optional[str] = None\n",
|
" description: Optional[str] = None\n",
|
||||||
|
" return_direct: bool = True\n",
|
||||||
"```\n",
|
"```\n",
|
||||||
"\n",
|
"\n",
|
||||||
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all."
|
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all. You can create these tools directly, but we also provide a decorator to easily convert any function into a tool."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -151,6 +153,94 @@
|
|||||||
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
|
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "824eaf74",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Using the `tool` decorator\n",
|
||||||
|
"\n",
|
||||||
|
"To make it easier to define custom tools, a `@tool` decorator is provided. This decorator can be used to quickly create a `Tool` from a simple function. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "8f15307d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.agents import tool\n",
|
||||||
|
"\n",
|
||||||
|
"@tool\n",
|
||||||
|
"def search_api(query: str) -> str:\n",
|
||||||
|
" \"\"\"Searches the API for the query.\"\"\"\n",
|
||||||
|
" return \"Results\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "0a23b91b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"Tool(name='search_api', func=<function search_api at 0x10dad7d90>, description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"search_api"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "cc6ee8c1",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You can also provide arguments like the tool name and whether to return directly."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "28cdf04d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"@tool(\"search\", return_direct=True)\n",
|
||||||
|
"def search_api(query: str) -> str:\n",
|
||||||
|
" \"\"\"Searches the API for the query.\"\"\"\n",
|
||||||
|
" return \"Results\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "1085a4bd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"Tool(name='search', func=<function search_api at 0x112301bd0>, description='search(query: str) -> str - Searches the API for the query.', return_direct=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"search_api"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1d0430d6",
|
"id": "1d0430d6",
|
||||||
@ -432,7 +522,7 @@
|
|||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "cb23c3a7a387ab03496baa08507270f8e0861b23170e79d5edc545893cdca840"
|
"hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -7,7 +7,7 @@ from langchain.agents.loading import load_agent
|
|||||||
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
||||||
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
||||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool, tool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MRKLChain",
|
"MRKLChain",
|
||||||
@ -16,6 +16,7 @@ __all__ = [
|
|||||||
"AgentExecutor",
|
"AgentExecutor",
|
||||||
"Agent",
|
"Agent",
|
||||||
"Tool",
|
"Tool",
|
||||||
|
"tool",
|
||||||
"initialize_agent",
|
"initialize_agent",
|
||||||
"ZeroShotAgent",
|
"ZeroShotAgent",
|
||||||
"ReActTextWorldAgent",
|
"ReActTextWorldAgent",
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Interface for tools."""
|
"""Interface for tools."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Optional
|
from inspect import signature
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -11,3 +12,65 @@ class Tool:
|
|||||||
func: Callable[[str], str]
|
func: Callable[[str], str]
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
return_direct: bool = False
|
return_direct: bool = False
|
||||||
|
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> str:
|
||||||
|
"""Make tools callable by piping through to `func`."""
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def tool(
|
||||||
|
*args: Union[str, Callable], return_direct: bool = False
|
||||||
|
) -> Union[Callable, Tool]:
|
||||||
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- Function must be of type (str) -> str
|
||||||
|
- Function must have a docstring
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
# Searches the API for the query.
|
||||||
|
return
|
||||||
|
|
||||||
|
@tool("search", return_direct=True)
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
# Searches the API for the query.
|
||||||
|
return
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _make_with_name(tool_name: str) -> Callable:
|
||||||
|
def _make_tool(func: Callable[[str], str]) -> Tool:
|
||||||
|
assert func.__doc__, "Function must have a docstring"
|
||||||
|
# Description example:
|
||||||
|
# search_api(query: str) - Searches the API for the query.
|
||||||
|
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||||
|
tool = Tool(
|
||||||
|
name=tool_name,
|
||||||
|
func=func,
|
||||||
|
description=description,
|
||||||
|
return_direct=return_direct,
|
||||||
|
)
|
||||||
|
return tool
|
||||||
|
|
||||||
|
return _make_tool
|
||||||
|
|
||||||
|
if len(args) == 1 and isinstance(args[0], str):
|
||||||
|
# if the argument is a string, then we use the string as the tool name
|
||||||
|
# Example usage: @tool("search", return_direct=True)
|
||||||
|
return _make_with_name(args[0])
|
||||||
|
elif len(args) == 1 and callable(args[0]):
|
||||||
|
# if the argument is a function, then we use the function name as the tool name
|
||||||
|
# Example usage: @tool
|
||||||
|
return _make_with_name(args[0].__name__)(args[0])
|
||||||
|
elif len(args) == 0:
|
||||||
|
# if there are no arguments, then we use the function name as the tool name
|
||||||
|
# Example usage: @tool(return_direct=True)
|
||||||
|
def _partial(func: Callable[[str], str]) -> Tool:
|
||||||
|
return _make_with_name(func.__name__)(func)
|
||||||
|
|
||||||
|
return _partial
|
||||||
|
else:
|
||||||
|
raise ValueError("Too many arguments for tool decorator")
|
||||||
|
67
tests/unit_tests/agents/test_tools.py
Normal file
67
tests/unit_tests/agents/test_tools.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
"""Test tool utils."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents.tools import Tool, tool
|
||||||
|
|
||||||
|
|
||||||
|
def test_unnamed_decorator() -> None:
|
||||||
|
"""Test functionality with unnamed decorator."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
"""Search the API for the query."""
|
||||||
|
return "API result"
|
||||||
|
|
||||||
|
assert isinstance(search_api, Tool)
|
||||||
|
assert search_api.name == "search_api"
|
||||||
|
assert not search_api.return_direct
|
||||||
|
assert search_api("test") == "API result"
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_tool_decorator() -> None:
|
||||||
|
"""Test functionality when arguments are provided as input to decorator."""
|
||||||
|
|
||||||
|
@tool("search")
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
"""Search the API for the query."""
|
||||||
|
return "API result"
|
||||||
|
|
||||||
|
assert isinstance(search_api, Tool)
|
||||||
|
assert search_api.name == "search"
|
||||||
|
assert not search_api.return_direct
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_tool_decorator_return_direct() -> None:
|
||||||
|
"""Test functionality when arguments and return direct are provided as input."""
|
||||||
|
|
||||||
|
@tool("search", return_direct=True)
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
"""Search the API for the query."""
|
||||||
|
return "API result"
|
||||||
|
|
||||||
|
assert isinstance(search_api, Tool)
|
||||||
|
assert search_api.name == "search"
|
||||||
|
assert search_api.return_direct
|
||||||
|
|
||||||
|
|
||||||
|
def test_unnamed_tool_decorator_return_direct() -> None:
|
||||||
|
"""Test functionality when only return direct is provided."""
|
||||||
|
|
||||||
|
@tool(return_direct=True)
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
"""Search the API for the query."""
|
||||||
|
return "API result"
|
||||||
|
|
||||||
|
assert isinstance(search_api, Tool)
|
||||||
|
assert search_api.name == "search_api"
|
||||||
|
assert search_api.return_direct
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_docstring() -> None:
|
||||||
|
"""Test error is raised when docstring is missing."""
|
||||||
|
# expect to throw a value error if theres no docstring
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def search_api(query: str) -> str:
|
||||||
|
return "API result"
|
Loading…
Reference in New Issue
Block a user