mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Harrison/tool decorator (#790)
Co-authored-by: Jason Liu <jxnl@users.noreply.github.com> Co-authored-by: Jason Liu <jason@jxnl.coA>
This commit is contained in:
parent
5f73d06502
commit
1ad7973cc6
@ -10,15 +10,17 @@
|
||||
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. A Tool is defined as below.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"class Tool(NamedTuple):\n",
|
||||
"@dataclass \n",
|
||||
"class Tool:\n",
|
||||
" \"\"\"Interface for tools.\"\"\"\n",
|
||||
"\n",
|
||||
" name: str\n",
|
||||
" func: Callable[[str], str]\n",
|
||||
" description: Optional[str] = None\n",
|
||||
" return_direct: bool = True\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all."
|
||||
"The two required components of a Tool are the name and then the tool itself. A tool description is optional, as it is needed for some agents but not all. You can create these tools directly, but we also provide a decorator to easily convert any function into a tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -151,6 +153,94 @@
|
||||
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "824eaf74",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using the `tool` decorator\n",
|
||||
"\n",
|
||||
"To make it easier to define custom tools, a `@tool` decorator is provided. This decorator can be used to quickly create a `Tool` from a simple function. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. Additionally, the decorator will use the function's docstring as the tool's description."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "8f15307d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import tool\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def search_api(query: str) -> str:\n",
|
||||
" \"\"\"Searches the API for the query.\"\"\"\n",
|
||||
" return \"Results\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "0a23b91b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Tool(name='search_api', func=<function search_api at 0x10dad7d90>, description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False)"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"search_api"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cc6ee8c1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can also provide arguments like the tool name and whether to return directly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "28cdf04d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tool(\"search\", return_direct=True)\n",
|
||||
"def search_api(query: str) -> str:\n",
|
||||
" \"\"\"Searches the API for the query.\"\"\"\n",
|
||||
" return \"Results\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "1085a4bd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Tool(name='search', func=<function search_api at 0x112301bd0>, description='search(query: str) -> str - Searches the API for the query.', return_direct=True)"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"search_api"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1d0430d6",
|
||||
@ -432,7 +522,7 @@
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "cb23c3a7a387ab03496baa08507270f8e0861b23170e79d5edc545893cdca840"
|
||||
"hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -7,7 +7,7 @@ from langchain.agents.loading import load_agent
|
||||
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.tools import Tool, tool
|
||||
|
||||
__all__ = [
|
||||
"MRKLChain",
|
||||
@ -16,6 +16,7 @@ __all__ = [
|
||||
"AgentExecutor",
|
||||
"Agent",
|
||||
"Tool",
|
||||
"tool",
|
||||
"initialize_agent",
|
||||
"ZeroShotAgent",
|
||||
"ReActTextWorldAgent",
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Interface for tools."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -11,3 +12,65 @@ class Tool:
|
||||
func: Callable[[str], str]
|
||||
description: Optional[str] = None
|
||||
return_direct: bool = False
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""Make tools callable by piping through to `func`."""
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def tool(
|
||||
*args: Union[str, Callable], return_direct: bool = False
|
||||
) -> Union[Callable, Tool]:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
Requires:
|
||||
- Function must be of type (str) -> str
|
||||
- Function must have a docstring
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
|
||||
@tool("search", return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(func: Callable[[str], str]) -> Tool:
|
||||
assert func.__doc__, "Function must have a docstring"
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||
tool = Tool(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
)
|
||||
return tool
|
||||
|
||||
return _make_tool
|
||||
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
# if the argument is a string, then we use the string as the tool name
|
||||
# Example usage: @tool("search", return_direct=True)
|
||||
return _make_with_name(args[0])
|
||||
elif len(args) == 1 and callable(args[0]):
|
||||
# if the argument is a function, then we use the function name as the tool name
|
||||
# Example usage: @tool
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
elif len(args) == 0:
|
||||
# if there are no arguments, then we use the function name as the tool name
|
||||
# Example usage: @tool(return_direct=True)
|
||||
def _partial(func: Callable[[str], str]) -> Tool:
|
||||
return _make_with_name(func.__name__)(func)
|
||||
|
||||
return _partial
|
||||
else:
|
||||
raise ValueError("Too many arguments for tool decorator")
|
||||
|
67
tests/unit_tests/agents/test_tools.py
Normal file
67
tests/unit_tests/agents/test_tools.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""Test tool utils."""
|
||||
import pytest
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
"""Test functionality with unnamed decorator."""
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
assert search_api.name == "search_api"
|
||||
assert not search_api.return_direct
|
||||
assert search_api("test") == "API result"
|
||||
|
||||
|
||||
def test_named_tool_decorator() -> None:
|
||||
"""Test functionality when arguments are provided as input to decorator."""
|
||||
|
||||
@tool("search")
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
assert search_api.name == "search"
|
||||
assert not search_api.return_direct
|
||||
|
||||
|
||||
def test_named_tool_decorator_return_direct() -> None:
|
||||
"""Test functionality when arguments and return direct are provided as input."""
|
||||
|
||||
@tool("search", return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
assert search_api.name == "search"
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_unnamed_tool_decorator_return_direct() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
assert search_api.name == "search_api"
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_missing_docstring() -> None:
|
||||
"""Test error is raised when docstring is missing."""
|
||||
# expect to throw a value error if theres no docstring
|
||||
with pytest.raises(AssertionError):
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
return "API result"
|
Loading…
Reference in New Issue
Block a user