Compare commits

...

2 Commits

Author       SHA1         Message                     Date
vowelparrot  5b72451f9d   pass in pydantic object     2023-04-26 14:13:07 -07:00
vowelparrot  5c0d36a0db   Structured Tool Refactor    2023-04-26 11:50:05 -07:00
12 changed files with 706 additions and 783 deletions

View File

@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "5436020b",
"metadata": {},
@@ -12,7 +13,6 @@
"- name (str), is required and must be unique within a set of tools provided to an agent\n",
"- description (str), is optional but recommended, as it is used by an agent to determine tool use\n",
"- return_direct (bool), defaults to False\n",
"- args_schema (Pydantic BaseModel), is optional but recommended, can be used to provide more information or validation for expected parameters.\n",
"\n",
"The function that should be called when the tool is selected should return a single string.\n",
"\n",
@@ -28,11 +28,13 @@
},
"outputs": [],
"source": [
"import logging\n",
"logging.basicConfig(level=logging.ERROR)\n",
"# Import things that are needed generically\n",
"from langchain import LLMMathChain, SerpAPIWrapper\n",
"from langchain import LLMMathChain\n",
"from langchain.agents import AgentType, Tool, initialize_agent, tool\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.tools import BaseTool"
"from langchain.tools import BaseTool, DuckDuckGoSearchTool"
]
},
{
@@ -56,14 +58,16 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f8bc72c2",
"metadata": {},
"source": [
"## Completely New Tools \n",
"## Creating New Tools \n",
"\n",
"First, we show how to create completely new tools from scratch.\n",
"\n",
"There are two ways to do this: either by using the Tool dataclass, or by subclassing the BaseTool class."
"There are two ways to do this: either by passing a function to the Tool dataclass, or by subclassing the BaseTool class."
]
},
{
@@ -84,28 +88,16 @@
"outputs": [],
"source": [
"# Load the tool configs that are needed.\n",
"search = SerpAPIWrapper()\n",
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\"\n",
" ),\n",
" DuckDuckGoSearchTool(),\n",
"]\n",
"# You can also define an args_schema to provide more information about inputs\n",
"from pydantic import BaseModel, Field\n",
"\n",
"class CalculatorInput(BaseModel):\n",
" question: str = Field()\n",
" \n",
"\n",
"tools.append(\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\",\n",
" args_schema=CalculatorInput\n",
" )\n",
")"
]
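Read as the updated cell, this hunk swaps the SerpAPI search tool for `DuckDuckGoSearchTool` and keeps a calculator tool whose input is described by a pydantic schema; a sketch of the resulting setup (the `ChatOpenAI` setup is assumed from the notebook's earlier import):

```python
# Sketch of the post-refactor tool list from this cell.
from pydantic import BaseModel, Field
from langchain import LLMMathChain
from langchain.agents import Tool
from langchain.chat_models import ChatOpenAI
from langchain.tools import DuckDuckGoSearchTool

class CalculatorInput(BaseModel):
    question: str = Field()

llm = ChatOpenAI(temperature=0)  # assumed setup, mirroring the notebook's ChatOpenAI import
llm_math_chain = LLMMathChain(llm=llm, verbose=True)
tools = [
    DuckDuckGoSearchTool(),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
        args_schema=CalculatorInput,
    ),
]
```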
@@ -139,22 +131,22 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name first before I can calculate her age raised to the 0.43 power.\n",
"Action: DuckDuckGo Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[32;1m\u001b[1;3mNow that I know Leo DiCaprio's girlfriend's name is Camila Morrone and her age is 25, I can calculate her age raised to the 0.43 power.\n",
"Action: Calculator\n",
"Action Input: 25^(0.43)\u001b[0m\n",
"Action Input: 25^0.43\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"25**(0.43)\n",
"25^0.43\u001b[32;1m\u001b[1;3m```text\n",
"25**0.43\n",
"```\n",
"...numexpr.evaluate(\"25**(0.43)\")...\n",
"...numexpr.evaluate(\"25**0.43\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: 3.991298452658078\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -162,7 +154,7 @@
{
"data": {
"text/plain": [
"'3.991298452658078'"
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
]
},
"execution_count": 5,
@@ -175,11 +167,15 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6f12eaf0",
"metadata": {},
"source": [
"### Subclassing the BaseTool class"
"### Subclassing the BaseTool class\n",
"\n",
"\n",
"The `BaseTool` and its child classes all expect a single input string and return a single string."
]
},
{
@@ -191,24 +187,33 @@
},
"outputs": [],
"source": [
"from typing import Type\n",
"from pydantic import Field\n",
"\n",
"class CustomSearchTool(BaseTool):\n",
" name = \"Search\"\n",
" description = \"useful for when you need to answer questions about current events\"\n",
" search_tool: DuckDuckGoSearchTool = Field(default_factory=DuckDuckGoSearchTool)\n",
"\n",
" def _run(self, query: str) -> str:\n",
" \"\"\"Use the tool.\"\"\"\n",
" return search.run(query)\n",
" return self.search_tool.run(query)\n",
" \n",
" async def _arun(self, query: str) -> str:\n",
" \"\"\"Use the tool asynchronously.\"\"\"\n",
" raise NotImplementedError(\"BingSearchRun does not support async\")\n",
" \n",
" raise NotImplementedError(\"BingSearchRun does not support async\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0bb44724",
"metadata": {},
"outputs": [],
"source": [
"\n",
"class CustomCalculatorTool(BaseTool):\n",
" name = \"Calculator\"\n",
" description = \"useful for when you need to answer questions about math\"\n",
" args_schema: Type[BaseModel] = CalculatorInput\n",
"\n",
" def _run(self, query: str) -> str:\n",
" \"\"\"Use the tool.\"\"\"\n",
@@ -221,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "3318a46f",
"metadata": {
"tags": []
@@ -233,7 +238,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "ee2d0f3a",
"metadata": {
"tags": []
@@ -245,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "6a2cebbf",
"metadata": {
"tags": []
@@ -260,20 +265,22 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
"Action: Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's age\n",
"Action: Search\n",
"Action Input: \"Camila Morrone age\"\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate Camila Morrone's age raised to the 0.43 power\n",
"Action: Calculator\n",
"Action Input: 25^(0.43)\u001b[0m\n",
"Action Input: 24^0.43\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
"25**(0.43)\n",
"24^0.43\u001b[32;1m\u001b[1;3m```text\n",
"24 ** 0.43\n",
"```\n",
"...numexpr.evaluate(\"25**(0.43)\")...\n",
"...numexpr.evaluate(\"24 ** 0.43\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m3.9218486893172186\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: 3.991298452658078\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI now know the final answer\n",
"Final Answer: 3.9218486893172186\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -281,10 +288,10 @@
{
"data": {
"text/plain": [
"'3.991298452658078'"
"'3.9218486893172186'"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -305,7 +312,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "8f15307d",
"metadata": {
"tags": []
@@ -322,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "0a23b91b",
"metadata": {
"tags": []
@@ -331,10 +338,10 @@
{
"data": {
"text/plain": [
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd664c0>, coroutine=None)"
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1032c4a30>, args_schema=<class 'langchain.tools.base.StringSchema'>, func=<function search_api at 0x1180fe040>, coroutine=None)"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -353,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "28cdf04d",
"metadata": {
"tags": []
@@ -368,62 +375,17 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "1085a4bd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd66310>, coroutine=None)"
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1032c4a30>, args_schema=<class 'langchain.tools.base.StringSchema'>, func=<function search_api at 0x1180fe5e0>, coroutine=None)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"search_api"
]
},
{
"cell_type": "markdown",
"id": "de34a6a3",
"metadata": {},
"source": [
"You can also provide `args_schema` to provide more information about the argument"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f3a5c106",
"metadata": {},
"outputs": [],
"source": [
"class SearchInput(BaseModel):\n",
" query: str = Field(description=\"should be a search query\")\n",
" \n",
"@tool(\"search\", return_direct=True, args_schema=SearchInput)\n",
"def search_api(query: str) -> str:\n",
" \"\"\"Searches the API for the query.\"\"\"\n",
" return \"Results\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "7914ba6b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class '__main__.SearchInput'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bcf0ee0>, coroutine=None)"
]
},
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -444,7 +406,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "79213f40",
"metadata": {},
"outputs": [],
@@ -454,17 +416,17 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "e1067dcb",
"metadata": {},
"outputs": [],
"source": [
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)"
"tools = load_tools([\"ddg-search\", \"llm-math\"], llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "6c66ffe8",
"metadata": {},
"outputs": [],
@@ -474,7 +436,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "f45b5bc3",
"metadata": {},
"outputs": [],
@@ -484,7 +446,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "565e2b9b",
"metadata": {},
"outputs": [
@@ -495,12 +457,14 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n",
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name first before I can calculate her age raised to the 0.43 power.\n",
"Action: Google Search\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n",
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[32;1m\u001b[1;3mI need to refine my search query to find out Camila Morrone's age.\n",
"Action: Google Search\n",
"Action Input: \"Camila Morrone age\"\u001b[0m\u001b[32;1m\u001b[1;3mNow that I know Camila Morrone's age, I can calculate her age raised to the 0.43 power.\n",
"Action: Calculator\n",
"Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
"Action Input: 25^(0.43)\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
"Final Answer: 3.991298452658078\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -508,10 +472,10 @@
{
"data": {
"text/plain": [
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
"'3.991298452658078'"
]
},
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -537,7 +501,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "3450512e",
"metadata": {},
"outputs": [],
@@ -546,14 +510,10 @@
"from langchain.agents import initialize_agent, Tool\n",
"from langchain.agents import AgentType\n",
"from langchain.llms import OpenAI\n",
"from langchain import LLMMathChain, SerpAPIWrapper\n",
"search = SerpAPIWrapper()\n",
"from langchain import LLMMathChain\n",
"from langchain.tools import DuckDuckGoSearchTool\n",
"tools = [\n",
" Tool(\n",
" name = \"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\"\n",
" ),\n",
" DuckDuckGoSearchTool(),\n",
" Tool(\n",
" name=\"Music Search\",\n",
" func=lambda x: \"'All I Want For Christmas Is You' by Mariah Carey.\", #Mock Function\n",
@@ -566,7 +526,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "4b9a7849",
"metadata": {},
"outputs": [
@@ -579,7 +539,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n",
"Action: Music Search\n",
"Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Action Input: most famous song of christmas\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
@@ -591,7 +551,7 @@
"\"'All I Want For Christmas Is You' by Mariah Carey.\""
]
},
"execution_count": 20,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -611,7 +571,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"id": "3bb6185f",
"metadata": {},
"outputs": [],
@@ -629,7 +589,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 23,
"id": "113ddb84",
"metadata": {},
"outputs": [],
@@ -640,7 +600,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"id": "582439a6",
"metadata": {
"tags": []
@@ -655,7 +615,7 @@
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
"Action: Calculator\n",
"Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"Action Input: 2**.12\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -666,7 +626,7 @@
"'Answer: 1.086734862526058'"
]
},
"execution_count": 23,
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@@ -674,153 +634,6 @@
"source": [
"agent.run(\"whats 2**.12\")"
]
},
{
"cell_type": "markdown",
"id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa",
"metadata": {},
"source": [
"## Multi-argument tools\n",
"\n",
"Many functions expect structured inputs. These can also be supported using the Tool decorator or by directly subclassing `BaseTool`! We have to modify the LLM's OutputParser to map its string output to a dictionary to pass to the action, however."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "537bc628",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from typing import Optional, Union\n",
"\n",
"@tool\n",
"def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n",
" \"\"\"The custom search function.\"\"\"\n",
" return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\""
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import re\n",
"from langchain.schema import (\n",
" AgentAction,\n",
" AgentFinish,\n",
")\n",
"from langchain.agents import AgentOutputParser\n",
"\n",
"# We will add a custom parser to map the arguments to a dictionary\n",
"class CustomOutputParser(AgentOutputParser):\n",
" \n",
" def parse_tool_input(self, action_input: str) -> dict:\n",
" # Regex pattern to match arguments and their values\n",
" pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n",
" matches = re.findall(pattern, action_input)\n",
" \n",
" if not matches:\n",
" raise ValueError(f\"Could not parse action input: `{action_input}`\")\n",
"\n",
" # Create a dictionary with the parsed arguments and their values\n",
" parsed_input = {}\n",
" for arg, value in matches:\n",
" if value == \"None\":\n",
" parsed_value = None\n",
" elif value.isdigit():\n",
" parsed_value = int(value)\n",
" else:\n",
" parsed_value = value.strip('\"')\n",
" parsed_input[arg] = parsed_value\n",
"\n",
" return parsed_input\n",
" \n",
" def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
" # Check if agent should finish\n",
" if \"Final Answer:\" in llm_output:\n",
" return AgentFinish(\n",
" # Return values is generally always a dictionary with a single `output` key\n",
" # It is not recommended to try anything else at the moment :)\n",
" return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
" log=llm_output,\n",
" )\n",
" # Parse out the action and action input\n",
" regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n",
" match = re.search(regex, llm_output, re.DOTALL)\n",
" if not match:\n",
" raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
" action = match.group(1).strip()\n",
" action_input = match.group(2)\n",
" tool_input = self.parse_tool_input(action_input)\n",
" # Return the action and action \n",
" return AgentAction(tool=action, tool_input=tool_input, log=llm_output)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "68269547-1482-4138-a6ea-58f00b4a9548",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0947835a-691c-4f51-b8f4-6744e0e48ab1",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n",
"Action: custom_search\n",
"Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The results of the custom search for k=1, query=me, other_arg=None.'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"Search for me and tell me whatever it says\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caf39c66-102b-42c1-baf2-777a49886ce4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -839,7 +652,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.16"
},
"vscode": {
"interpreter": {

View File

@@ -1,14 +1,12 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Optional, Union
from pydantic import BaseModel, validate_arguments, validator
from pydantic import validator
from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
@@ -28,22 +26,14 @@ class Tool(BaseTool):
raise ValueError("Partial functions not yet supported in tools.")
return func
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
return self.func(*args, **kwargs)
return self.func(tool_input)
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
return await self.coroutine(tool_input)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
@@ -74,8 +64,6 @@ class InvalidTool(BaseTool):
def tool(
*args: Union[str, Callable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -83,10 +71,6 @@ def tool(
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
Requires:
- Function must be of type (str) -> str
@@ -112,13 +96,9 @@ def tool(
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
)
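With `args_schema` and `infer_schema` removed from the decorator, `tool` is back to wrapping plain (str) -> str functions; a brief usage sketch consistent with the docstring:

```python
# Usage sketch of the trimmed-down @tool decorator: the wrapped function
# must be (str) -> str and must have a docstring.
from langchain.agents import tool

@tool("search", return_direct=True)
def search_api(query: str) -> str:
    """Searches the API for the query."""
    return "Results"

print(search_api.run("langchain"))  # -> "Results"
```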

View File

@@ -2,17 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from typing import Any, Dict, Generic, List, NamedTuple, Optional, Sequence, TypeVar
from pydantic import BaseModel, Extra, Field, root_validator
@@ -41,7 +31,7 @@ class AgentAction(NamedTuple):
"""Agent's action to take."""
tool: str
tool_input: Union[str, dict]
tool_input: str
log: str
@@ -401,8 +391,6 @@ class OutputParserException(Exception):
errors will be raised.
"""
pass
class BaseDocumentTransformer(ABC):
"""Base interface for transforming documents."""

View File

@@ -7,6 +7,11 @@ from langchain.tools.ifttt import IFTTTWebhook
from langchain.tools.openapi.utils.api_models import APIOperation
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain.tools.plugin import AIPluginTool
from langchain.tools.structured import (
BaseStructuredTool,
StructuredTool,
structured_tool,
)
__all__ = [
"BaseTool",
@@ -16,4 +21,7 @@ __all__ = [
"APIOperation",
"GooglePlacesTool",
"DuckDuckGoSearchTool",
"BaseStructuredTool",
"StructuredTool",
"structured_tool",
]

View File

@@ -2,157 +2,37 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, Optional, Type, Union
from pydantic import (
BaseModel,
Extra,
Field,
create_model,
validate_arguments,
validator,
)
from pydantic.main import ModelMetaclass
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.tools.structured import AbstractStructuredTool
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input
class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
class ToolMetaclass(ModelMetaclass):
"""Metaclass for BaseTool to ensure the provided args_schema
isn't silently ignored."""
def __new__(
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
) -> ToolMetaclass:
"""Create the definition of the new tool class."""
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
if schema_type is not None:
schema_annotations = dct.get("__annotations__", {})
args_schema_type = schema_annotations.get("args_schema", None)
if args_schema_type is None or args_schema_type == BaseModel:
# Throw errors for common mis-annotations.
# TODO: Use get_args / get_origin and fully
# specify valid annotations.
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
..."""
raise SchemaAnnotationError(
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
)
# Pass through to Pydantic's metaclass
return super().__new__(cls, name, bases, dct)
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
)
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
class BaseTool(ABC, AbstractStructuredTool[str, str]):
"""Interface LangChain tools must implement."""
name: str
description: str
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[str] = str # :meta private:
class Config:
"""Configuration for this pydantic object."""
def _wrap_input(self, tool_input: Union[str, Dict]) -> Dict:
"""Wrap the tool's input into a pydantic model."""
if isinstance(tool_input, Dict):
return tool_input
return {"tool_input": tool_input}
extra = Extra.forbid
arbitrary_types_allowed = True
def _parse_input(self, input_: Dict) -> str:
"""Prepare the args and kwargs for the tool."""
return next(iter(input_.values()))
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self._run).model # type: ignore
return get_filtered_args(inferred_model, self._run)
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> None:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
else:
if input_args is not None:
input_args.validate(tool_input)
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
def args(self) -> Dict:
"""Return the JSON schema for the tool's args."""
return {"properties": {"tool_input": {"type": "string"}}}
@abstractmethod
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
@abstractmethod
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
def run(
@@ -163,29 +43,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
color: Optional[str] = "green",
**kwargs: Any,
) -> str:
"""Run the tool."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
"""Use the tool."""
wrapped_input = self._wrap_input(tool_input)
return super().run(wrapped_input, verbose, start_color, color, **kwargs)
async def arun(
self,
@@ -195,48 +55,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
color: Optional[str] = "green",
**kwargs: Any,
) -> str:
"""Run the tool asynchronously."""
self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
if self.callback_manager.is_async:
await self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
else:
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
else:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
else:
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
"""Use the tool asynchronously."""
wrapped_input = self._wrap_input(tool_input)
return await super().arun(wrapped_input, verbose, start_color, color, **kwargs)
def __call__(self, tool_input: str) -> str:
"""Make tool callable."""
def __call__(self, tool_input: Union[str, Dict]) -> str:
return self.run(tool_input)
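After the refactor, `BaseTool` keeps its single-string interface and delegates callbacks and input wrapping to `AbstractStructuredTool`; a minimal subclass sketch under that reading:

```python
# Minimal BaseTool subclass under the refactored string-in/string-out contract.
from langchain.tools.base import BaseTool

class EchoTool(BaseTool):
    name = "echo"
    description = "Returns its input unchanged"

    def _run(self, tool_input: str) -> str:
        return tool_input

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError("EchoTool does not support async")

# A plain string is wrapped into {"tool_input": ...} internally by run().
assert EchoTool().run("hello") == "hello"
```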

View File

@@ -1,21 +1,12 @@
from pathlib import Path
from typing import Optional, Type
from pydantic import BaseModel, Field
from typing import Optional
from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
class ReadFileInput(BaseModel):
"""Input for ReadFileTool."""
file_path: str = Field(..., description="name of file")
class ReadFileTool(BaseTool):
name: str = "read_file"
args_schema: Type[BaseModel] = ReadFileInput
description: str = "Read file from disk"
root_dir: Optional[str] = None
"""Directory to read file from.
@@ -35,6 +26,6 @@ class ReadFileTool(BaseTool):
except Exception as e:
return "Error: " + str(e)
async def _arun(self, tool_input: str) -> str:
async def _arun(self, file_path: str) -> str:
# TODO: Add aiofiles method
raise NotImplementedError

View File

@@ -3,8 +3,8 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.tools.file_management.utils import get_validated_relative_path
from langchain.tools.structured import BaseStructuredTool
class WriteFileInput(BaseModel):
@@ -14,29 +14,29 @@ class WriteFileInput(BaseModel):
text: str = Field(..., description="text to write to file")
class WriteFileTool(BaseTool):
class WriteFileTool(BaseStructuredTool[str, WriteFileInput]):
name: str = "write_file"
args_schema: Type[BaseModel] = WriteFileInput
args_schema: Type[WriteFileInput] = WriteFileInput
description: str = "Write file to disk"
root_dir: Optional[str] = None
"""Directory to write file to.
If specified, raises an error for file_paths outside root_dir."""
def _run(self, file_path: str, text: str) -> str:
def _run(self, tool_input: WriteFileInput) -> str:
write_path = (
get_validated_relative_path(Path(self.root_dir), file_path)
get_validated_relative_path(Path(self.root_dir), tool_input.file_path)
if self.root_dir
else Path(file_path)
else Path(tool_input.file_path)
)
try:
write_path.parent.mkdir(exist_ok=True, parents=False)
with write_path.open("w", encoding="utf-8") as f:
f.write(text)
return f"File written successfully to {file_path}."
f.write(tool_input.text)
return f"File written successfully to {tool_input.file_path}."
except Exception as e:
return "Error: " + str(e)
async def _arun(self, file_path: str, text: str) -> str:
async def _arun(self, tool_input: WriteFileInput) -> str:
# TODO: Add aiofiles method
raise NotImplementedError
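With `WriteFileTool` converted to a `BaseStructuredTool`, its `run()` now takes a dict that is validated into `WriteFileInput`; a usage sketch (the import path and root_dir are assumptions for illustration):

```python
# Hypothetical call after this change; the import path is assumed from
# the file_management package referenced in the diff.
from langchain.tools.file_management.write import WriteFileTool

tool = WriteFileTool(root_dir="/tmp")
print(tool.run({"file_path": "example.txt", "text": "hello"}))
# Expected: "File written successfully to example.txt."
```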

View File

@@ -127,11 +127,11 @@ class ListPowerBITool(BaseTool):
arbitrary_types_allowed = True
def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, tool_input: str = "") -> str:
"""Get the names of the tables."""
return ", ".join(self.powerbi.get_table_names())
async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, tool_input: str = "") -> str:
"""Get the names of the tables."""
return ", ".join(self.powerbi.get_table_names())

View File

@@ -0,0 +1,360 @@
from __future__ import annotations
import logging
from abc import abstractmethod
from inspect import Parameter, signature
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
)
from pydantic import BaseModel, Extra, Field, create_model, validate_arguments
from pydantic.generics import GenericModel
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
logger = logging.getLogger(__name__)
OUTPUT_T = TypeVar("OUTPUT_T")
RUN_T = TypeVar("RUN_T", bound=Union[str, BaseModel])
class AbstractStructuredTool(
GenericModel,
Generic[OUTPUT_T, RUN_T],
BaseModel,
):
"""Parent class for all structured tools."""
name: str
description: str
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
args_schema: Type[RUN_T] # :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
@abstractmethod
def args(self) -> Dict:
"""Return the JSON schema for the tool's args."""
def _get_verbosity(
self,
verbose: Optional[bool] = None,
) -> bool:
"""Return the verbosity of the tool run."""
return verbose if (not self.verbose and verbose is not None) else self.verbose
@staticmethod
async def _async_or_sync_call(
method: Callable, *args: Any, is_async: bool, **kwargs: Any
) -> Any:
"""Run the callback manager method asynchronously or synchronously."""
if is_async:
return await method(*args, **kwargs)
else:
return method(*args, **kwargs)
@abstractmethod
def _parse_input(self, tool_input: dict) -> RUN_T:
"""Parse the tool input into the object for _run."""
@abstractmethod
def _run(self, tool_input: RUN_T) -> OUTPUT_T:
"""Use the tool."""
@abstractmethod
async def _arun(self, tool_input: RUN_T) -> OUTPUT_T:
"""Use the tool asynchronously."""
def run(
self,
tool_input: dict,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUTPUT_T:
"""Run the tool."""
parsed_input = self._parse_input(tool_input)
verbose_ = self._get_verbosity(verbose)
self.callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
str(tool_input),
verbose=verbose_,
color=start_color,
**kwargs,
)
try:
observation = self._run(parsed_input)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_end(
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
)
return observation
async def arun(
self,
tool_input: dict,
verbose: Optional[bool] = None,
start_color: Optional[str] = "green",
color: Optional[str] = "green",
**kwargs: Any,
) -> OUTPUT_T:
"""Run the tool asynchronously."""
parsed_input = self._parse_input(tool_input)
verbose_ = self._get_verbosity(verbose)
await self._async_or_sync_call(
self.callback_manager.on_tool_start,
{"name": self.name, "description": self.description},
str(tool_input),
verbose=verbose_,
color=start_color,
is_async=self.callback_manager.is_async,
**kwargs,
)
try:
observation = await self._arun(parsed_input)
except (Exception, KeyboardInterrupt) as e:
await self._async_or_sync_call(
self.callback_manager.on_tool_error,
e,
verbose=verbose_,
is_async=self.callback_manager.is_async,
)
raise e
await self._async_or_sync_call(
self.callback_manager.on_tool_end,
str(observation),
verbose=verbose_,
color=color,
is_async=self.callback_manager.is_async,
**kwargs,
)
return observation
def __call__(self, tool_input: dict) -> OUTPUT_T:
"""Make tool callable."""
return self.run(tool_input)
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
"""Create a pydantic model with only a subset of model's fields."""
fields = {
field_name: (
model.__fields__[field_name].type_,
model.__fields__[field_name].default,
)
for field_name in field_names
if field_name in model.__fields__
}
return create_model(name, **fields) # type: ignore
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
def _warn_args_kwargs(func: Callable) -> None:
# Check if the function has *args or **kwargs.
# Tools don't interact well with these.
sig = signature(func)
for param in sig.parameters.values():
if param.kind == Parameter.VAR_POSITIONAL:
logger.warning(f"{func.__name__} uses *args, which are not well supported.")
elif param.kind == Parameter.VAR_KEYWORD:
logger.warning(
f"{func.__name__} uses **kwargs, which are not well supported."
)
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
_warn_args_kwargs(func)
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
filtered_args = get_filtered_args(inferred_model, func)
return _create_subset_model(
f"{model_name}Schema", inferred_model, list(filtered_args)
)
STRUCTURED_RUN_T = TypeVar("STRUCTURED_RUN_T", bound=BaseModel)
class BaseStructuredTool(
AbstractStructuredTool[OUTPUT_T, STRUCTURED_RUN_T],
Generic[OUTPUT_T, STRUCTURED_RUN_T],
):
"""The structured tool that requires _run classes to take in a base model."""
args_schema: Type[STRUCTURED_RUN_T] # :meta private:
def _parse_input(self, tool_input: dict) -> STRUCTURED_RUN_T:
return self.args_schema.parse_obj(tool_input)
@property
def args(self) -> Dict:
"""Return the JSON schema for the tool's args."""
schema = self.args_schema.schema()
result = {"properties": schema["properties"]}
if "definitions" in schema:
result["definitions"] = schema["definitions"]
return result
class StructuredTool(BaseStructuredTool[Any, BaseModel]):
"""StructuredTool that takes in function or coroutine directly."""
func: Callable[..., Any]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
args_schema: Type[BaseModel] # :meta private:
@property
def args(self) -> dict:
"""The JSON Schema arguments for the tool."""
return self.args_schema.schema()["properties"]
def _run(self, tool_input: BaseModel) -> Any:
"""Use the tool."""
parsed_dict = tool_input.dict()
parsed_dict = {k: getattr(tool_input, k) for k in parsed_dict.keys()}
return self.func(**parsed_dict)
async def _arun(self, tool_input: BaseModel) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
parsed_dict = tool_input.dict()
parsed_dict = {k: getattr(tool_input, k) for k in parsed_dict.keys()}
return await self.coroutine(**parsed_dict)
raise NotImplementedError(f"StructuredTool {self.name} does not support async")
@classmethod
def from_function(
cls,
func: Callable[..., Any],
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
name: Optional[str] = None,
description: Optional[str] = None,
) -> "StructuredTool":
"""Make tools out of functions, can be used with or without arguments.
Args:
func: The function to run when the tool is called.
coroutine: The asynchronous version of the function.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: optional argument schema for user to specify
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
name: The name of the tool. Defaults to the function name.
description: The description of the tool. Defaults to the function
docstring.
"""
description = func.__doc__ or description
if description is None or not description.strip():
raise ValueError(
f"Function {func.__name__} must have a docstring, or set description."
)
name = name or func.__name__
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func)
description = f"{name}{signature(func)} - {description}"
return cls(
name=name,
func=func,
coroutine=coroutine,
return_direct=return_direct,
args_schema=_args_schema,
description=description,
)
def structured_tool(
*args: Union[str, Callable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
Args:
*args: The arguments to the tool.
return_direct: Whether to return directly from the tool rather
than continuing the agent loop.
args_schema: Optional argument schema for user to specify. If
none, will infer the schema from the function's signature.
Requires:
- Function must be of type (str) -> str
- Function must have a docstring
Examples:
.. code-block:: python
@tool
def search_api(query: str) -> str:
# Searches the API for the query.
return
@tool("search", return_direct=True)
def search_api(query: str) -> str:
# Searches the API for the query.
return
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> StructuredTool:
return StructuredTool.from_function(
name=tool_name,
func=func,
args_schema=args_schema,
return_direct=return_direct,
)
return _make_tool
if len(args) == 1 and isinstance(args[0], str):
# if the argument is a string, then we use the string as the tool name
# Example usage: @tool("search", return_direct=True)
return _make_with_name(args[0])
elif len(args) == 1 and callable(args[0]):
# if the argument is a function, then we use the function name as the tool name
# Example usage: @tool
return _make_with_name(args[0].__name__)(args[0])
elif len(args) == 0:
# if there are no arguments, then we use the function name as the tool name
# Example usage: @tool(return_direct=True)
def _partial(func: Callable[[str], str]) -> BaseStructuredTool:
return _make_with_name(func.__name__)(func)
return _partial
else:
raise ValueError("Too many arguments for tool decorator")
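A brief sketch of the new `structured_tool` decorator in use, mirroring the tests added later in this diff: the argument schema is inferred from the function signature and `run()` takes a matching dict:

```python
# Schema is inferred from the signature; run() accepts a dict of arguments.
from typing import Optional
from langchain.tools.structured import StructuredTool, structured_tool

@structured_tool
def describe(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
    """Return the arguments directly."""
    return f"{arg1} {arg2} {arg3}"

assert isinstance(describe, StructuredTool)
print(describe.run({"arg1": 1, "arg2": True}))  # -> "1 True None"
```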

View File

@@ -1,9 +1,8 @@
"""Util that calls OpenWeatherMap using PyOWM."""
from typing import Any, Dict, Optional
from pydantic import Extra, root_validator
from pydantic import BaseModel, Extra, root_validator
from langchain.tools.base import BaseModel
from langchain.utils import get_from_dict_or_env

View File

@@ -1,14 +1,11 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
import pydantic
from typing import Optional
import pytest
from pydantic import BaseModel
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError
from langchain.tools.base import BaseTool
def test_unnamed_decorator() -> None:
@@ -23,169 +20,7 @@ def test_unnamed_decorator() -> None:
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_unannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool without type hints raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _UnAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _MisAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema: BaseModel = _MockSchema # type: ignore
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_forward_ref_annotated_base_tool_accepted() -> None:
"""Test that a using forward ref annotation syntax is accepted.""" ""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: "Type[BaseModel]" = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_subclass_annotated_base_tool_accepted() -> None:
"""Test BaseTool child w/ custom schema isn't overwritten."""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: Type[_MockSchema] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
tool = _ForwardRefAnnotatedTool()
assert tool.args_schema == _MockSchema
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@tool(args_schema=_MockSchema)
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, Tool)
assert tool_func.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@tool
def structured_tool_input(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, Tool)
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
== structured_tool_input.args
)
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, Tool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
assert structured_tool_input.run(args) == expected_result
def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def unstructured_tool_input(tool_input: str) -> str:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, Tool)
assert unstructured_tool_input.args_schema is None
assert search_api.args_schema == str
def test_base_tool_inheritance_base_schema() -> None:
@@ -202,8 +37,8 @@ def test_base_tool_inheritance_base_schema() -> None:
raise NotImplementedError
simple_tool = _MockSimpleTool()
assert simple_tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
assert simple_tool.args_schema == str
expected_args = {"properties": {"tool_input": {"type": "string"}}}
assert simple_tool.args == expected_args
@@ -215,56 +50,11 @@ def test_tool_lambda_args_schema() -> None:
description="A tool",
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}}
assert tool.args_schema == str
expected_args = {"properties": {"tool_input": {"type": "string"}}}
assert tool.args == expected_args
def test_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
}
assert tool.args == expected_args
def test_tool_partial_function_args_schema() -> None:
"""Test args schema inference when the tool argument is a partial function."""
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
with pytest.raises(pydantic.error_wrappers.ValidationError):
# We don't yet support args_schema inference for partial functions
# so want to make sure we proactively raise an error
Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
def test_empty_args_decorator() -> None:
"""Test inferred schema of decorated fn with no args."""
@tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, Tool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""
@@ -304,32 +94,17 @@ def test_unnamed_tool_decorator_return_direct() -> None:
assert search_api.return_direct
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
def test_base_tool_decorator_multiple_args() -> None:
"""Test the schema that's generated is still a simple string."""
@tool(return_direct=True)
def search_api(
arg_1: float,
ping: str = "hi",
) -> str:
def some_tool(query: str, foo: int = 3, bar: Optional[dict] = None) -> str:
"""Search the API for the query."""
return f"arg_1={arg_1}, ping={ping}"
return f"{query} {foo} {bar}"
assert isinstance(search_api, Tool)
result = search_api.run(
tool_input={
"arg_1": 3.2,
"ping": "pong",
}
)
assert result == "arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_1": 3.2,
}
)
assert result == "arg_1=3.2, ping=hi"
assert isinstance(some_tool, Tool)
assert some_tool.name == "some_tool"
assert some_tool.run("foo") == "foo 3 None"
def test_missing_docstring() -> None:

View File

@@ -0,0 +1,188 @@
import logging
from functools import partial
from typing import Any, Optional, Type
import pydantic
import pytest
from pydantic import BaseModel
from langchain.tools.structured import (
BaseStructuredTool,
StructuredTool,
structured_tool,
)
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseStructuredTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, tool_input: _MockSchema) -> str:
return f"{tool_input.arg1} {tool_input.arg2} {tool_input.arg3}"
async def _arun(self, tool_input: _MockSchema) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseStructuredTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@structured_tool(args_schema=_MockSchema)
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, StructuredTool)
assert tool_func.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@structured_tool
def structured_tool_input(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, StructuredTool)
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
== structured_tool_input.args
)
def test_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
name="tool",
description="A tool",
)
assert set(tool.args_schema.schema()["properties"]) == {"tool_input", "other_arg"}
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
}
assert tool.args == expected_args
def test_tool_partial_function_args_schema() -> None:
"""Test args schema inference when the tool argument is a partial function."""
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
with pytest.raises(pydantic.error_wrappers.ValidationError):
# We don't yet support args_schema inference for partial functions
# so want to make sure we proactively raise an error
StructuredTool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
@structured_tool(return_direct=True)
def search_api(
arg_1: float,
ping: str = "hi",
) -> str:
"""Search the API for the query."""
return f"arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, StructuredTool)
result = search_api.run(
tool_input={
"arg_1": 3.2,
"ping": "pong",
}
)
assert result == "arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_1": 3.2,
}
)
assert result == "arg_1=3.2, ping=hi"
def test_empty_args_decorator() -> None:
"""Test inferred schema of decorated fn with no args."""
@structured_tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, StructuredTool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
def test_nested_pydantic_args() -> None:
"""Test inferred schema when args are nested pydantic models."""
# This is a pattern that is common with FastAPI methods.
# If we only parse a dict input but pass the dict
# to the function, we are limited only to primitive types
# in general.
class SomeNestedInput(BaseModel):
arg2: str
class SomeInput(BaseModel):
arg1: int
arg2: SomeNestedInput
@structured_tool
def nested_tool(some_input: SomeInput) -> dict:
"""Return a constant."""
return some_input.dict()
assert isinstance(nested_tool, StructuredTool)
assert nested_tool.name == "nested_tool"
input_ = {"some_input": {"arg1": 1, "arg2": {"arg2": "foo"}}}
assert nested_tool.run(input_) == input_["some_input"]
def test_warning_on_args_kwargs(caplog: pytest.LogCaptureFixture) -> None:
"""Test inferred schema when args are nested pydantic models."""
with caplog.at_level(logging.WARNING):
@structured_tool
def anything_goes(*foo: Any, **bar: Any) -> str:
"""Return a constant."""
return str(foo) + "|" + str(bar)
# Check if the expected warning message was logged
assert any(
"anything_goes uses *args" in record.message for record in caplog.records
)
assert any(
"anything_goes uses **kwargs" in record.message for record in caplog.records
)