langchain/docs/docs/modules/agents/tools/custom_tools.ipynb
Harrison Chase f20c56db41 [documentation] documentation revamp (#15281)
needs new versions of langchain-core and langchain

---------

Co-authored-by: Nuno Campos <nuno@langchain.dev>
2023-12-29 14:51:06 -08:00

{
"cells": [
{
"cell_type": "markdown",
"id": "5436020b",
"metadata": {},
"source": [
"# Defining Custom Tools\n",
"\n",
"When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n",
"\n",
"- `name` (str): required; must be unique within the set of tools provided to an agent\n",
"- `description` (str): optional but recommended, because the agent uses it to decide when to use the tool\n",
"- `args_schema` (Pydantic BaseModel): optional but recommended; can be used to provide more information (e.g., few-shot examples) or validation for expected parameters\n",
"\n",
"\n",
"There are multiple ways to define a tool. In this guide, we will walk through how to do this for two functions:\n",
"\n",
"1. A made-up search function that always returns the string \"LangChain\"\n",
"2. A multiplier function that multiplies two numbers together\n",
"\n",
"The biggest difference here is that the first function only requires one input, while the second one requires multiple. Many agents only work with functions that require single inputs, so it's important to know how to work with those. For the most part, defining these custom tools is the same, but there are some differences."
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "1aaba18c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Import things that are needed generically\n",
"from langchain.pydantic_v1 import BaseModel, Field\n",
"from langchain.tools import BaseTool, StructuredTool, tool"
]
},
{
"cell_type": "markdown",
"id": "c7326b23",
"metadata": {},
"source": [
"## @tool decorator\n",
"\n",
"The `@tool` decorator is the simplest way to define a custom tool. The decorator uses the function name as the tool name by default, but this can be overridden by passing a string as the first argument. The decorator also uses the function's docstring as the tool's description, so a docstring MUST be provided."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b0ce7de8",
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"def search(query: str) -> str:\n",
"    \"\"\"Look up things online.\"\"\"\n",
"    return \"LangChain\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e889fa34",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"search\n",
"search(query: str) -> str - Look up things online.\n",
"{'query': {'title': 'Query', 'type': 'string'}}\n"
]
}
],
"source": [
"print(search.name)\n",
"print(search.description)\n",
"print(search.args)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0b9694d9",
"metadata": {},
"outputs": [],
"source": [
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
"    \"\"\"Multiply two numbers.\"\"\"\n",
"    return a * b"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d7f9395b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"multiply\n",
"multiply(a: int, b: int) -> int - Multiply two numbers.\n",
"{'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}\n"
]
}
],
"source": [
"print(multiply.name)\n",
"print(multiply.description)\n",
"print(multiply.args)"
]
},
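{
"cell_type": "markdown",
"id": "3f9a1c2e",
"metadata": {},
"source": [
"As a minimal sketch of how these tools can be called directly (outside an agent), the cell below uses the `run` method, the same method this guide uses later in the error-handling section. It assumes the `search` and `multiply` tools defined in the cells above; the input values are arbitrary. A single-input tool accepts a plain string, while a multi-input tool takes a dict keyed by argument name."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b2d4e6a",
"metadata": {},
"outputs": [],
"source": [
"# Single-input tool: a plain string is passed as the only argument.\n",
"print(search.run(\"what is langchain\"))\n",
"\n",
"# Multi-input tool: pass a dict whose keys match the argument names.\n",
"print(multiply.run({\"a\": 2, \"b\": 3}))"
]
},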
{
"cell_type": "markdown",
"id": "98d6eee9",
"metadata": {},
"source": [
"You can also customize the tool name and JSON args by passing them into the tool decorator."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "dbbf4b6c",
"metadata": {},
"outputs": [],
"source": [
"class SearchInput(BaseModel):\n",
"    query: str = Field(description=\"should be a search query\")\n",
"\n",
"\n",
"@tool(\"search-tool\", args_schema=SearchInput, return_direct=True)\n",
"def search(query: str) -> str:\n",
"    \"\"\"Look up things online.\"\"\"\n",
"    return \"LangChain\""
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "5950ce32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"search-tool\n",
"search-tool(query: str) -> str - Look up things online.\n",
"{'query': {'title': 'Query', 'description': 'should be a search query', 'type': 'string'}}\n",
"True\n"
]
}
],
"source": [
"print(search.name)\n",
"print(search.description)\n",
"print(search.args)\n",
"print(search.return_direct)"
]
},
{
"cell_type": "markdown",
"id": "9d11e80c",
"metadata": {},
"source": [
"## Subclass BaseTool\n",
"\n",
"You can also explicitly define a custom tool by subclassing the BaseTool class. This provides maximal control over the tool definition, but is a bit more work."
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "1dad8f8e",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional, Type\n",
"\n",
"from langchain.callbacks.manager import (\n",
"    AsyncCallbackManagerForToolRun,\n",
"    CallbackManagerForToolRun,\n",
")\n",
"\n",
"\n",
"class SearchInput(BaseModel):\n",
"    query: str = Field(description=\"should be a search query\")\n",
"\n",
"\n",
"class CalculatorInput(BaseModel):\n",
"    a: int = Field(description=\"first number\")\n",
"    b: int = Field(description=\"second number\")\n",
"\n",
"\n",
"class CustomSearchTool(BaseTool):\n",
"    name = \"custom_search\"\n",
"    description = \"useful for when you need to answer questions about current events\"\n",
"    args_schema: Type[BaseModel] = SearchInput\n",
"\n",
"    def _run(\n",
"        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None\n",
"    ) -> str:\n",
"        \"\"\"Use the tool.\"\"\"\n",
"        return \"LangChain\"\n",
"\n",
"    async def _arun(\n",
"        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None\n",
"    ) -> str:\n",
"        \"\"\"Use the tool asynchronously.\"\"\"\n",
"        raise NotImplementedError(\"custom_search does not support async\")\n",
"\n",
"\n",
"class CustomCalculatorTool(BaseTool):\n",
"    name = \"Calculator\"\n",
"    description = \"useful for when you need to answer questions about math\"\n",
"    args_schema: Type[BaseModel] = CalculatorInput\n",
"    return_direct: bool = True\n",
"\n",
"    def _run(\n",
"        self, a: int, b: int, run_manager: Optional[CallbackManagerForToolRun] = None\n",
"    ) -> int:\n",
"        \"\"\"Use the tool.\"\"\"\n",
"        return a * b\n",
"\n",
"    async def _arun(\n",
"        self,\n",
"        a: int,\n",
"        b: int,\n",
"        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,\n",
"    ) -> int:\n",
"        \"\"\"Use the tool asynchronously.\"\"\"\n",
"        raise NotImplementedError(\"Calculator does not support async\")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "89933e27",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"custom_search\n",
"useful for when you need to answer questions about current events\n",
"{'query': {'title': 'Query', 'description': 'should be a search query', 'type': 'string'}}\n"
]
}
],
"source": [
"search = CustomSearchTool()\n",
"print(search.name)\n",
"print(search.description)\n",
"print(search.args)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "bb551c33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculator\n",
"useful for when you need to answer questions about math\n",
"{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n",
"True\n"
]
}
],
"source": [
"multiply = CustomCalculatorTool()\n",
"print(multiply.name)\n",
"print(multiply.description)\n",
"print(multiply.args)\n",
"print(multiply.return_direct)"
]
},
{
"cell_type": "markdown",
"id": "b63fcc3b",
"metadata": {},
"source": [
"## StructuredTool dataclass\n",
"\n",
"You can also use a `StructuredTool` dataclass. This method is a mix between the previous two. It's more convenient than inheriting from the BaseTool class, but provides more functionality than just using a decorator."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "56ff7670",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def search_function(query: str):\n",
"    return \"LangChain\"\n",
"\n",
"\n",
"search = StructuredTool.from_function(\n",
"    func=search_function,\n",
"    name=\"Search\",\n",
"    description=\"useful for when you need to answer questions about current events\",\n",
"    # coroutine= ... <- you can specify an async method if desired as well\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "d3fd3896",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Search\n",
"Search(query: str) - useful for when you need to answer questions about current events\n",
"{'query': {'title': 'Query', 'type': 'string'}}\n"
]
}
],
"source": [
"print(search.name)\n",
"print(search.description)\n",
"print(search.args)"
]
},
{
"cell_type": "markdown",
"id": "e9b560f7",
"metadata": {},
"source": [
"You can also define a custom `args_schema` to provide more information about inputs."
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "712c1967",
"metadata": {},
"outputs": [],
"source": [
"class CalculatorInput(BaseModel):\n",
"    a: int = Field(description=\"first number\")\n",
"    b: int = Field(description=\"second number\")\n",
"\n",
"\n",
"def multiply(a: int, b: int) -> int:\n",
"    \"\"\"Multiply two numbers.\"\"\"\n",
"    return a * b\n",
"\n",
"\n",
"calculator = StructuredTool.from_function(\n",
"    func=multiply,\n",
"    name=\"Calculator\",\n",
"    description=\"multiply numbers\",\n",
"    args_schema=CalculatorInput,\n",
"    return_direct=True,\n",
"    # coroutine= ... <- you can specify an async method if desired as well\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "f634081e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculator\n",
"Calculator(a: int, b: int) -> int - multiply numbers\n",
"{'a': {'title': 'A', 'description': 'first number', 'type': 'integer'}, 'b': {'title': 'B', 'description': 'second number', 'type': 'integer'}}\n"
]
}
],
"source": [
"print(calculator.name)\n",
"print(calculator.description)\n",
"print(calculator.args)"
]
},
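{
"cell_type": "markdown",
"id": "c4e8a0b6",
"metadata": {},
"source": [
"The `coroutine=` comment above hints that an async implementation can be supplied as well. The following is a rough sketch of that, assuming `StructuredTool.from_function` accepts a `coroutine` argument as the comment suggests and that the resulting tool can be awaited through `arun`. The names `multiply_sync`, `multiply_async`, and `async_calculator` are hypothetical, and the top-level `await` assumes a notebook environment such as Jupyter."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5f9b1c7",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"from langchain.tools import StructuredTool\n",
"\n",
"\n",
"def multiply_sync(a: int, b: int) -> int:\n",
"    \"\"\"Multiply two numbers.\"\"\"\n",
"    return a * b\n",
"\n",
"\n",
"async def multiply_async(a: int, b: int) -> int:\n",
"    \"\"\"Multiply two numbers without blocking the event loop.\"\"\"\n",
"    await asyncio.sleep(0)  # stand-in for real async work, e.g. an HTTP call\n",
"    return a * b\n",
"\n",
"\n",
"async_calculator = StructuredTool.from_function(\n",
"    func=multiply_sync,  # used by the synchronous run path\n",
"    coroutine=multiply_async,  # used by the asynchronous arun path\n",
"    name=\"Calculator\",\n",
"    description=\"multiply numbers\",\n",
")\n",
"\n",
"# Top-level await works in Jupyter; in a script, wrap this in asyncio.run(...).\n",
"await async_calculator.arun({\"a\": 2, \"b\": 3})"
]
},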
{
"cell_type": "markdown",
"id": "f1da459d",
"metadata": {},
"source": [
"## Handling Tool Errors\n",
"When a tool encounters an error and the exception is not caught, the agent will stop executing. If you want the agent to continue execution, you can raise a `ToolException` and set `handle_tool_error` accordingly.\n",
"\n",
"When a `ToolException` is raised, the agent does not stop. Instead, the exception is handled according to the tool's `handle_tool_error` setting, and the result is returned to the agent as an observation and printed in red.\n",
"\n",
"You can set `handle_tool_error` to `True`, to a fixed string, or to a function. If it is a function, it should take a `ToolException` as its parameter and return a `str`.\n",
"\n",
"Note that raising a `ToolException` alone is not enough. You must also set `handle_tool_error` on the tool, because its default value is `False`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8bf4668",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import ToolException\n",
"\n",
"\n",
"def search_tool1(s: str):\n",
"    raise ToolException(\"The search tool1 is not available.\")"
]
},
{
"cell_type": "markdown",
"id": "7fb56757",
"metadata": {},
"source": [
"First, let's see what happens if we don't set `handle_tool_error`: the `ToolException` propagates to the caller."
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "f3dfbcb0",
"metadata": {},
"outputs": [
{
"ename": "ToolException",
"evalue": "The search tool1 is not available.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mToolException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[58], line 7\u001b[0m\n\u001b[1;32m 1\u001b[0m search \u001b[38;5;241m=\u001b[39m StructuredTool\u001b[38;5;241m.\u001b[39mfrom_function(\n\u001b[1;32m 2\u001b[0m func\u001b[38;5;241m=\u001b[39msearch_tool1,\n\u001b[1;32m 3\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSearch_tool1\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 4\u001b[0m description\u001b[38;5;241m=\u001b[39mdescription,\n\u001b[1;32m 5\u001b[0m )\n\u001b[0;32m----> 7\u001b[0m \u001b[43msearch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtest\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/libs/core/langchain_core/tools.py:344\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, callbacks, tags, metadata, run_name, **kwargs)\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandle_tool_error:\n\u001b[1;32m 343\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_tool_error(e)\n\u001b[0;32m--> 344\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 345\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandle_tool_error, \u001b[38;5;28mbool\u001b[39m):\n\u001b[1;32m 346\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m e\u001b[38;5;241m.\u001b[39margs:\n",
"File \u001b[0;32m~/workplace/langchain/libs/core/langchain_core/tools.py:337\u001b[0m, in \u001b[0;36mBaseTool.run\u001b[0;34m(self, tool_input, verbose, start_color, color, callbacks, tags, metadata, run_name, **kwargs)\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 335\u001b[0m tool_args, tool_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_to_args_and_kwargs(parsed_input)\n\u001b[1;32m 336\u001b[0m observation \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 337\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrun_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_manager\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtool_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_arg_supported\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run(\u001b[38;5;241m*\u001b[39mtool_args, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mtool_kwargs)\n\u001b[1;32m 340\u001b[0m )\n\u001b[1;32m 341\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ToolException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 342\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandle_tool_error:\n",
"File \u001b[0;32m~/workplace/langchain/libs/core/langchain_core/tools.py:631\u001b[0m, in \u001b[0;36mStructuredTool._run\u001b[0;34m(self, run_manager, *args, **kwargs)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfunc:\n\u001b[1;32m 623\u001b[0m new_argument_supported \u001b[38;5;241m=\u001b[39m signature(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfunc)\u001b[38;5;241m.\u001b[39mparameters\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcallbacks\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 624\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 625\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfunc(\n\u001b[1;32m 626\u001b[0m \u001b[38;5;241m*\u001b[39margs,\n\u001b[1;32m 627\u001b[0m callbacks\u001b[38;5;241m=\u001b[39mrun_manager\u001b[38;5;241m.\u001b[39mget_child() \u001b[38;5;28;01mif\u001b[39;00m run_manager \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 628\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 629\u001b[0m )\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_argument_supported\n\u001b[0;32m--> 631\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m )\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTool does not support sync\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"Cell \u001b[0;32mIn[55], line 5\u001b[0m, in \u001b[0;36msearch_tool1\u001b[0;34m(s)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msearch_tool1\u001b[39m(s: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ToolException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe search tool1 is not available.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mToolException\u001b[0m: The search tool1 is not available."
]
}
],
"source": [
"search = StructuredTool.from_function(\n",
" func=search_tool1,\n",
" name=\"Search_tool1\",\n",
" description=\"A bad tool\",\n",
")\n",
"\n",
"search.run(\"test\")"
]
},
{
"cell_type": "markdown",
"id": "d2475acd",
"metadata": {},
"source": [
"Now, let's set `handle_tool_error` to `True`. The exception's message is returned instead of being raised."
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "ab81e0f0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The search tool1 is not available.'"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search = StructuredTool.from_function(\n",
" func=search_tool1,\n",
" name=\"Search_tool1\",\n",
" description=\"A bad tool\",\n",
" handle_tool_error=True,\n",
")\n",
"\n",
"search.run(\"test\")"
]
},
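{
"cell_type": "markdown",
"id": "e6a0c2d8",
"metadata": {},
"source": [
"`handle_tool_error` can also be a fixed string, as noted above. A minimal sketch, reusing `search_tool1` and `StructuredTool` from the earlier cells (the message text here is made up for illustration): the string should be returned as the observation in place of the exception's own message."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7b1d3e9",
"metadata": {},
"outputs": [],
"source": [
"search = StructuredTool.from_function(\n",
"    func=search_tool1,\n",
"    name=\"Search_tool1\",\n",
"    description=\"A bad tool\",\n",
"    # Any ToolException raised by the tool is replaced by this fixed message.\n",
"    handle_tool_error=\"The search tool failed. Please try another tool.\",\n",
")\n",
"\n",
"search.run(\"test\")"
]
},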
{
"cell_type": "markdown",
"id": "dafbbcbe",
"metadata": {},
"source": [
"We can also define a custom function to handle the tool error."
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "ad16fbcf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The following errors occurred during tool execution: The search tool1 is not available. Please try another tool.'"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def _handle_error(error: ToolException) -> str:\n",
"    # Separate the three parts with spaces so the message reads as sentences.\n",
"    return (\n",
"        \"The following errors occurred during tool execution: \"\n",
"        + error.args[0]\n",
"        + \" Please try another tool.\"\n",
"    )\n",
"\n",
"\n",
"search = StructuredTool.from_function(\n",
"    func=search_tool1,\n",
"    name=\"Search_tool1\",\n",
"    description=\"A bad tool\",\n",
"    handle_tool_error=_handle_error,\n",
")\n",
"\n",
"search.run(\"test\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
},
"vscode": {
"interpreter": {
"hash": "e90c8aa204a57276aa905271aff2d11799d0acb3547adabc5892e639a5e45e34"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}