mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 16:20:16 +00:00
Compare commits
2 Commits
bagatur/rf
...
vwp/struct
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b72451f9d | ||
|
|
5c0d36a0db |
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "5436020b",
|
||||
"metadata": {},
|
||||
@@ -12,7 +13,6 @@
|
||||
"- name (str), is required and must be unique within a set of tools provided to an agent\n",
|
||||
"- description (str), is optional but recommended, as it is used by an agent to determine tool use\n",
|
||||
"- return_direct (bool), defaults to False\n",
|
||||
"- args_schema (Pydantic BaseModel), is optional but recommended, can be used to provide more information or validation for expected parameters.\n",
|
||||
"\n",
|
||||
"The function that should be called when the tool is selected should return a single string.\n",
|
||||
"\n",
|
||||
@@ -28,11 +28,13 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.basicConfig(level=logging.ERROR)\n",
|
||||
"# Import things that are needed generically\n",
|
||||
"from langchain import LLMMathChain, SerpAPIWrapper\n",
|
||||
"from langchain import LLMMathChain\n",
|
||||
"from langchain.agents import AgentType, Tool, initialize_agent, tool\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.tools import BaseTool"
|
||||
"from langchain.tools import BaseTool, DuckDuckGoSearchTool"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -56,14 +58,16 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "f8bc72c2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Completely New Tools \n",
|
||||
"## Creating New Tools \n",
|
||||
"\n",
|
||||
"First, we show how to create completely new tools from scratch.\n",
|
||||
"\n",
|
||||
"There are two ways to do this: either by using the Tool dataclass, or by subclassing the BaseTool class."
|
||||
"There are two ways to do this: either by passing a function to the Tool dataclass, or by subclassing the BaseTool class."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -84,28 +88,16 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the tool configs that are needed.\n",
|
||||
"search = SerpAPIWrapper()\n",
|
||||
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name = \"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events\"\n",
|
||||
" ),\n",
|
||||
" DuckDuckGoSearchTool(),\n",
|
||||
"]\n",
|
||||
"# You can also define an args_schema to provide more information about inputs\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"\n",
|
||||
"class CalculatorInput(BaseModel):\n",
|
||||
" question: str = Field()\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"tools.append(\n",
|
||||
" Tool(\n",
|
||||
" name=\"Calculator\",\n",
|
||||
" func=llm_math_chain.run,\n",
|
||||
" description=\"useful for when you need to answer questions about math\",\n",
|
||||
" args_schema=CalculatorInput\n",
|
||||
" )\n",
|
||||
")"
|
||||
]
|
||||
@@ -139,22 +131,22 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name first before I can calculate her age raised to the 0.43 power.\n",
|
||||
"Action: DuckDuckGo Search\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[32;1m\u001b[1;3mNow that I know Leo DiCaprio's girlfriend's name is Camila Morrone and her age is 25, I can calculate her age raised to the 0.43 power.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 25^(0.43)\u001b[0m\n",
|
||||
"Action Input: 25^0.43\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||||
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"25**(0.43)\n",
|
||||
"25^0.43\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"25**0.43\n",
|
||||
"```\n",
|
||||
"...numexpr.evaluate(\"25**(0.43)\")...\n",
|
||||
"...numexpr.evaluate(\"25**0.43\")...\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||||
"Final Answer: 3.991298452658078\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
|
||||
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@@ -162,7 +154,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'3.991298452658078'"
|
||||
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
@@ -175,11 +167,15 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "6f12eaf0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Subclassing the BaseTool class"
|
||||
"### Subclassing the BaseTool class\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The `BaseTool` and its child classes all expect a single input string and return a single string."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -191,24 +187,33 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Type\n",
|
||||
"from pydantic import Field\n",
|
||||
"\n",
|
||||
"class CustomSearchTool(BaseTool):\n",
|
||||
" name = \"Search\"\n",
|
||||
" description = \"useful for when you need to answer questions about current events\"\n",
|
||||
" search_tool: DuckDuckGoSearchTool = Field(default_factory=DuckDuckGoSearchTool)\n",
|
||||
"\n",
|
||||
" def _run(self, query: str) -> str:\n",
|
||||
" \"\"\"Use the tool.\"\"\"\n",
|
||||
" return search.run(query)\n",
|
||||
" return self.search_tool.run(query)\n",
|
||||
" \n",
|
||||
" async def _arun(self, query: str) -> str:\n",
|
||||
" \"\"\"Use the tool asynchronously.\"\"\"\n",
|
||||
" raise NotImplementedError(\"BingSearchRun does not support async\")\n",
|
||||
" \n",
|
||||
" raise NotImplementedError(\"BingSearchRun does not support async\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "0bb44724",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"class CustomCalculatorTool(BaseTool):\n",
|
||||
" name = \"Calculator\"\n",
|
||||
" description = \"useful for when you need to answer questions about math\"\n",
|
||||
" args_schema: Type[BaseModel] = CalculatorInput\n",
|
||||
"\n",
|
||||
" def _run(self, query: str) -> str:\n",
|
||||
" \"\"\"Use the tool.\"\"\"\n",
|
||||
@@ -221,7 +226,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"id": "3318a46f",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -233,7 +238,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"id": "ee2d0f3a",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -245,7 +250,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 10,
|
||||
"id": "6a2cebbf",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -260,20 +265,22 @@
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mDiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's current age\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Camila Morrone's age\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Camila Morrone age\"\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate Camila Morrone's age raised to the 0.43 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 25^(0.43)\u001b[0m\n",
|
||||
"Action Input: 24^0.43\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||||
"25^(0.43)\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"25**(0.43)\n",
|
||||
"24^0.43\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"24 ** 0.43\n",
|
||||
"```\n",
|
||||
"...numexpr.evaluate(\"25**(0.43)\")...\n",
|
||||
"...numexpr.evaluate(\"24 ** 0.43\")...\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m3.9218486893172186\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||||
"Final Answer: 3.991298452658078\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI now know the final answer\n",
|
||||
"Final Answer: 3.9218486893172186\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@@ -281,10 +288,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'3.991298452658078'"
|
||||
"'3.9218486893172186'"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -305,7 +312,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 11,
|
||||
"id": "8f15307d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -322,7 +329,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 12,
|
||||
"id": "0a23b91b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -331,10 +338,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd664c0>, coroutine=None)"
|
||||
"Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1032c4a30>, args_schema=<class 'langchain.tools.base.StringSchema'>, func=<function search_api at 0x1180fe040>, coroutine=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -353,7 +360,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 13,
|
||||
"id": "28cdf04d",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -368,62 +375,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 14,
|
||||
"id": "1085a4bd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class 'pydantic.main.SearchApi'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bd66310>, coroutine=None)"
|
||||
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x1032c4a30>, args_schema=<class 'langchain.tools.base.StringSchema'>, func=<function search_api at 0x1180fe5e0>, coroutine=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"search_api"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de34a6a3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can also provide `args_schema` to provide more information about the argument"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "f3a5c106",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SearchInput(BaseModel):\n",
|
||||
" query: str = Field(description=\"should be a search query\")\n",
|
||||
" \n",
|
||||
"@tool(\"search\", return_direct=True, args_schema=SearchInput)\n",
|
||||
"def search_api(query: str) -> str:\n",
|
||||
" \"\"\"Searches the API for the query.\"\"\"\n",
|
||||
" return \"Results\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "7914ba6b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=<class '__main__.SearchInput'>, return_direct=True, verbose=False, callback_manager=<langchain.callbacks.shared.SharedCallbackManager object at 0x12748c4c0>, func=<function search_api at 0x16bcf0ee0>, coroutine=None)"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -444,7 +406,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 15,
|
||||
"id": "79213f40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -454,17 +416,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 16,
|
||||
"id": "e1067dcb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)"
|
||||
"tools = load_tools([\"ddg-search\", \"llm-math\"], llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 17,
|
||||
"id": "6c66ffe8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -474,7 +436,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 18,
|
||||
"id": "f45b5bc3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -484,7 +446,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 19,
|
||||
"id": "565e2b9b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -495,12 +457,14 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n",
|
||||
"\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name first before I can calculate her age raised to the 0.43 power.\n",
|
||||
"Action: Google Search\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n",
|
||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[32;1m\u001b[1;3mI need to refine my search query to find out Camila Morrone's age.\n",
|
||||
"Action: Google Search\n",
|
||||
"Action Input: \"Camila Morrone age\"\u001b[0m\u001b[32;1m\u001b[1;3mNow that I know Camila Morrone's age, I can calculate her age raised to the 0.43 power.\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
|
||||
"Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n",
|
||||
"Action Input: 25^(0.43)\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
|
||||
"Final Answer: 3.991298452658078\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@@ -508,10 +472,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\""
|
||||
"'3.991298452658078'"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -537,7 +501,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 20,
|
||||
"id": "3450512e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -546,14 +510,10 @@
|
||||
"from langchain.agents import initialize_agent, Tool\n",
|
||||
"from langchain.agents import AgentType\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain import LLMMathChain, SerpAPIWrapper\n",
|
||||
"search = SerpAPIWrapper()\n",
|
||||
"from langchain import LLMMathChain\n",
|
||||
"from langchain.tools import DuckDuckGoSearchTool\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name = \"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events\"\n",
|
||||
" ),\n",
|
||||
" DuckDuckGoSearchTool(),\n",
|
||||
" Tool(\n",
|
||||
" name=\"Music Search\",\n",
|
||||
" func=lambda x: \"'All I Want For Christmas Is You' by Mariah Carey.\", #Mock Function\n",
|
||||
@@ -566,7 +526,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 21,
|
||||
"id": "4b9a7849",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -579,7 +539,7 @@
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n",
|
||||
"Action: Music Search\n",
|
||||
"Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Action Input: most famous song of christmas\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
@@ -591,7 +551,7 @@
|
||||
"\"'All I Want For Christmas Is You' by Mariah Carey.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -611,7 +571,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 22,
|
||||
"id": "3bb6185f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -629,7 +589,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 23,
|
||||
"id": "113ddb84",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -640,7 +600,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 24,
|
||||
"id": "582439a6",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -655,7 +615,7 @@
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||
"Action Input: 2**.12\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@@ -666,7 +626,7 @@
|
||||
"'Answer: 1.086734862526058'"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -674,153 +634,6 @@
|
||||
"source": [
|
||||
"agent.run(\"whats 2**.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Multi-argument tools\n",
|
||||
"\n",
|
||||
"Many functions expect structured inputs. These can also be supported using the Tool decorator or by directly subclassing `BaseTool`! We have to modify the LLM's OutputParser to map its string output to a dictionary to pass to the action, however."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "537bc628",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Optional, Union\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n",
|
||||
" \"\"\"The custom search function.\"\"\"\n",
|
||||
" return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"from langchain.schema import (\n",
|
||||
" AgentAction,\n",
|
||||
" AgentFinish,\n",
|
||||
")\n",
|
||||
"from langchain.agents import AgentOutputParser\n",
|
||||
"\n",
|
||||
"# We will add a custom parser to map the arguments to a dictionary\n",
|
||||
"class CustomOutputParser(AgentOutputParser):\n",
|
||||
" \n",
|
||||
" def parse_tool_input(self, action_input: str) -> dict:\n",
|
||||
" # Regex pattern to match arguments and their values\n",
|
||||
" pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n",
|
||||
" matches = re.findall(pattern, action_input)\n",
|
||||
" \n",
|
||||
" if not matches:\n",
|
||||
" raise ValueError(f\"Could not parse action input: `{action_input}`\")\n",
|
||||
"\n",
|
||||
" # Create a dictionary with the parsed arguments and their values\n",
|
||||
" parsed_input = {}\n",
|
||||
" for arg, value in matches:\n",
|
||||
" if value == \"None\":\n",
|
||||
" parsed_value = None\n",
|
||||
" elif value.isdigit():\n",
|
||||
" parsed_value = int(value)\n",
|
||||
" else:\n",
|
||||
" parsed_value = value.strip('\"')\n",
|
||||
" parsed_input[arg] = parsed_value\n",
|
||||
"\n",
|
||||
" return parsed_input\n",
|
||||
" \n",
|
||||
" def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
|
||||
" # Check if agent should finish\n",
|
||||
" if \"Final Answer:\" in llm_output:\n",
|
||||
" return AgentFinish(\n",
|
||||
" # Return values is generally always a dictionary with a single `output` key\n",
|
||||
" # It is not recommended to try anything else at the moment :)\n",
|
||||
" return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
|
||||
" log=llm_output,\n",
|
||||
" )\n",
|
||||
" # Parse out the action and action input\n",
|
||||
" regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n",
|
||||
" match = re.search(regex, llm_output, re.DOTALL)\n",
|
||||
" if not match:\n",
|
||||
" raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
|
||||
" action = match.group(1).strip()\n",
|
||||
" action_input = match.group(2)\n",
|
||||
" tool_input = self.parse_tool_input(action_input)\n",
|
||||
" # Return the action and action \n",
|
||||
" return AgentAction(tool=action, tool_input=tool_input, log=llm_output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"id": "68269547-1482-4138-a6ea-58f00b4a9548",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"id": "0947835a-691c-4f51-b8f4-6744e0e48ab1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n",
|
||||
"Action: custom_search\n",
|
||||
"Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The results of the custom search for k=1, query=me, other_arg=None.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"Search for me and tell me whatever it says\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "caf39c66-102b-42c1-baf2-777a49886ce4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -839,7 +652,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.8.16"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Interface for tools."""
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Optional, Type, Union
|
||||
from typing import Any, Awaitable, Callable, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validate_arguments, validator
|
||||
from pydantic import validator
|
||||
|
||||
from langchain.tools.base import (
|
||||
BaseTool,
|
||||
create_schema_from_function,
|
||||
get_filtered_args,
|
||||
)
|
||||
|
||||
|
||||
@@ -28,22 +26,14 @@ class Tool(BaseTool):
|
||||
raise ValueError("Partial functions not yet supported in tools.")
|
||||
return func
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
inferred_model = validate_arguments(self.func).model # type: ignore
|
||||
return get_filtered_args(inferred_model, self.func)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
return self.func(*args, **kwargs)
|
||||
return self.func(tool_input)
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
return await self.coroutine(tool_input)
|
||||
raise NotImplementedError("Tool does not support async")
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
@@ -74,8 +64,6 @@ class InvalidTool(BaseTool):
|
||||
def tool(
|
||||
*args: Union[str, Callable],
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
@@ -83,10 +71,6 @@ def tool(
|
||||
*args: The arguments to the tool.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop.
|
||||
args_schema: optional argument schema for user to specify
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
|
||||
Requires:
|
||||
- Function must be of type (str) -> str
|
||||
@@ -112,13 +96,9 @@ def tool(
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
|
||||
tool_ = Tool(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
args_schema=_args_schema,
|
||||
description=description,
|
||||
return_direct=return_direct,
|
||||
)
|
||||
|
||||
@@ -2,17 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Dict, Generic, List, NamedTuple, Optional, Sequence, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
@@ -41,7 +31,7 @@ class AgentAction(NamedTuple):
|
||||
"""Agent's action to take."""
|
||||
|
||||
tool: str
|
||||
tool_input: Union[str, dict]
|
||||
tool_input: str
|
||||
log: str
|
||||
|
||||
|
||||
@@ -401,8 +391,6 @@ class OutputParserException(Exception):
|
||||
errors will be raised.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
"""Base interface for transforming documents."""
|
||||
|
||||
@@ -7,6 +7,11 @@ from langchain.tools.ifttt import IFTTTWebhook
|
||||
from langchain.tools.openapi.utils.api_models import APIOperation
|
||||
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
|
||||
from langchain.tools.plugin import AIPluginTool
|
||||
from langchain.tools.structured import (
|
||||
BaseStructuredTool,
|
||||
StructuredTool,
|
||||
structured_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
@@ -16,4 +21,7 @@ __all__ = [
|
||||
"APIOperation",
|
||||
"GooglePlacesTool",
|
||||
"DuckDuckGoSearchTool",
|
||||
"BaseStructuredTool",
|
||||
"StructuredTool",
|
||||
"structured_tool",
|
||||
]
|
||||
|
||||
@@ -2,157 +2,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
create_model,
|
||||
validate_arguments,
|
||||
validator,
|
||||
)
|
||||
from pydantic.main import ModelMetaclass
|
||||
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.tools.structured import AbstractStructuredTool
|
||||
|
||||
|
||||
def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
|
||||
# For backwards compatability, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(run_input, str):
|
||||
return (run_input,), {}
|
||||
else:
|
||||
return [], run_input
|
||||
|
||||
|
||||
class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
class ToolMetaclass(ModelMetaclass):
|
||||
"""Metaclass for BaseTool to ensure the provided args_schema
|
||||
|
||||
doesn't silently ignored."""
|
||||
|
||||
def __new__(
|
||||
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
|
||||
) -> ToolMetaclass:
|
||||
"""Create the definition of the new tool class."""
|
||||
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
|
||||
if schema_type is not None:
|
||||
schema_annotations = dct.get("__annotations__", {})
|
||||
args_schema_type = schema_annotations.get("args_schema", None)
|
||||
if args_schema_type is None or args_schema_type == BaseModel:
|
||||
# Throw errors for common mis-annotations.
|
||||
# TODO: Use get_args / get_origin and fully
|
||||
# specify valid annotations.
|
||||
typehint_mandate = """
|
||||
class ChildTool(BaseTool):
|
||||
...
|
||||
args_schema: Type[BaseModel] = SchemaClass
|
||||
..."""
|
||||
raise SchemaAnnotationError(
|
||||
f"Tool definition for {name} must include valid type annotations"
|
||||
f" for argument 'args_schema' to behave as expected.\n"
|
||||
f"Expected annotation of 'Type[BaseModel]'"
|
||||
f" but got '{args_schema_type}'.\n"
|
||||
f"Expected class looks like:\n"
|
||||
f"{typehint_mandate}"
|
||||
)
|
||||
# Pass through to Pydantic's metaclass
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {
|
||||
field_name: (
|
||||
model.__fields__[field_name].type_,
|
||||
model.__fields__[field_name].default,
|
||||
)
|
||||
for field_name in field_names
|
||||
if field_name in model.__fields__
|
||||
}
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys}
|
||||
|
||||
|
||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature."""
|
||||
inferred_model = validate_arguments(func).model # type: ignore
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
filtered_args = get_filtered_args(inferred_model, func)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||
)
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
class BaseTool(ABC, AbstractStructuredTool[str, str]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
"""Pydantic model class to validate and parse the tool's input arguments."""
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
args_schema: Type[str] = str # :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
def _wrap_input(self, tool_input: Union[str, Dict]) -> Dict:
|
||||
"""Wrap the tool's input into a pydantic model."""
|
||||
if isinstance(tool_input, Dict):
|
||||
return tool_input
|
||||
return {"tool_input": tool_input}
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
def _parse_input(self, input_: Dict) -> str:
|
||||
"""Prepare the args and kwargs for the tool."""
|
||||
return next(iter(input_.values()))
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
inferred_model = validate_arguments(self._run).model # type: ignore
|
||||
return get_filtered_args(inferred_model, self._run)
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> None:
|
||||
"""Convert tool input to pydantic model."""
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
if input_args is not None:
|
||||
key_ = next(iter(input_args.__fields__.keys()))
|
||||
input_args.validate({key_: tool_input})
|
||||
else:
|
||||
if input_args is not None:
|
||||
input_args.validate(tool_input)
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
def args(self) -> Dict:
|
||||
"""Return the JSON schema for the tool's args."""
|
||||
return {"properties": {"tool_input": {"type": "string"}}}
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
@@ -163,29 +43,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool."""
|
||||
self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
|
||||
observation = self._run(*tool_args, **tool_kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
"""Use the tool."""
|
||||
wrapped_input = self._wrap_input(tool_input)
|
||||
return super().run(wrapped_input, verbose, start_color, color, **kwargs)
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
@@ -195,48 +55,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool asynchronously."""
|
||||
self._parse_input(tool_input)
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
args, kwargs = _to_args_and_kwargs(tool_input)
|
||||
observation = await self._arun(*args, **kwargs)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
else:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
"""Use the tool asynchronously."""
|
||||
wrapped_input = self._wrap_input(tool_input)
|
||||
return await super().arun(wrapped_input, verbose, start_color, color, **kwargs)
|
||||
|
||||
def __call__(self, tool_input: str) -> str:
|
||||
"""Make tool callable."""
|
||||
def __call__(self, tool_input: Union[str, Dict]) -> str:
|
||||
return self.run(tool_input)
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||
|
||||
|
||||
class ReadFileInput(BaseModel):
|
||||
"""Input for ReadFileTool."""
|
||||
|
||||
file_path: str = Field(..., description="name of file")
|
||||
|
||||
|
||||
class ReadFileTool(BaseTool):
|
||||
name: str = "read_file"
|
||||
args_schema: Type[BaseModel] = ReadFileInput
|
||||
description: str = "Read file from disk"
|
||||
root_dir: Optional[str] = None
|
||||
"""Directory to read file from.
|
||||
@@ -35,6 +26,6 @@ class ReadFileTool(BaseTool):
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
async def _arun(self, file_path: str) -> str:
|
||||
# TODO: Add aiofiles method
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.file_management.utils import get_validated_relative_path
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
@@ -14,29 +14,29 @@ class WriteFileInput(BaseModel):
|
||||
text: str = Field(..., description="text to write to file")
|
||||
|
||||
|
||||
class WriteFileTool(BaseTool):
|
||||
class WriteFileTool(BaseStructuredTool[str, WriteFileInput]):
|
||||
name: str = "write_file"
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
args_schema: Type[WriteFileInput] = WriteFileInput
|
||||
description: str = "Write file to disk"
|
||||
root_dir: Optional[str] = None
|
||||
"""Directory to write file to.
|
||||
|
||||
If specified, raises an error for file_paths oustide root_dir."""
|
||||
|
||||
def _run(self, file_path: str, text: str) -> str:
|
||||
def _run(self, tool_input: WriteFileInput) -> str:
|
||||
write_path = (
|
||||
get_validated_relative_path(Path(self.root_dir), file_path)
|
||||
get_validated_relative_path(Path(self.root_dir), tool_input.file_path)
|
||||
if self.root_dir
|
||||
else Path(file_path)
|
||||
else Path(tool_input.file_path)
|
||||
)
|
||||
try:
|
||||
write_path.parent.mkdir(exist_ok=True, parents=False)
|
||||
with write_path.open("w", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
return f"File written successfully to {file_path}."
|
||||
f.write(tool_input.text)
|
||||
return f"File written successfully to {tool_input.file_path}."
|
||||
except Exception as e:
|
||||
return "Error: " + str(e)
|
||||
|
||||
async def _arun(self, file_path: str, text: str) -> str:
|
||||
async def _arun(self, tool_input: WriteFileInput) -> str:
|
||||
# TODO: Add aiofiles method
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -127,11 +127,11 @@ class ListPowerBITool(BaseTool):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> str:
|
||||
def _run(self, tool_input: str = "") -> str:
|
||||
"""Get the names of the tables."""
|
||||
return ", ".join(self.powerbi.get_table_names())
|
||||
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
async def _arun(self, tool_input: str = "") -> str:
|
||||
"""Get the names of the tables."""
|
||||
return ", ".join(self.powerbi.get_table_names())
|
||||
|
||||
|
||||
360
langchain/tools/structured.py
Normal file
360
langchain/tools/structured.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from inspect import Parameter, signature
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, create_model, validate_arguments
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OUTPUT_T = TypeVar("OUTPUT_T")
|
||||
RUN_T = TypeVar("RUN_T", bound=Union[str, BaseModel])
|
||||
|
||||
|
||||
class AbstractStructuredTool(
|
||||
GenericModel,
|
||||
Generic[OUTPUT_T, RUN_T],
|
||||
BaseModel,
|
||||
):
|
||||
"""Parent class for all structured tools."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
args_schema: Type[RUN_T] # :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def args(self) -> Dict:
|
||||
"""Return the JSON schema for the tool's args."""
|
||||
|
||||
def _get_verbosity(
|
||||
self,
|
||||
verbose: Optional[bool] = None,
|
||||
) -> bool:
|
||||
"""Return the verbosity of the tool run."""
|
||||
return verbose if (not self.verbose and verbose is not None) else self.verbose
|
||||
|
||||
@staticmethod
|
||||
async def _async_or_sync_call(
|
||||
method: Callable, *args: Any, is_async: bool, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run the callback manager method asynchronously or synchronously."""
|
||||
if is_async:
|
||||
return await method(*args, **kwargs)
|
||||
else:
|
||||
return method(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _parse_input(self, tool_input: dict) -> RUN_T:
|
||||
"""Parse the tool input into the object for _run."""
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, tool_input: RUN_T) -> OUTPUT_T:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, tool_input: RUN_T) -> OUTPUT_T:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: dict,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUTPUT_T:
|
||||
"""Run the tool."""
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
observation = self._run(parsed_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: dict,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUTPUT_T:
|
||||
"""Run the tool asynchronously."""
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
await self._async_or_sync_call(
|
||||
self.callback_manager.on_tool_start,
|
||||
{"name": self.name, "description": self.description},
|
||||
str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
observation = await self._arun(parsed_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await self._async_or_sync_call(
|
||||
self.callback_manager.on_tool_error,
|
||||
e,
|
||||
verbose=verbose_,
|
||||
is_async=self.callback_manager.is_async,
|
||||
)
|
||||
raise e
|
||||
await self._async_or_sync_call(
|
||||
self.callback_manager.on_tool_end,
|
||||
str(observation),
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: dict) -> OUTPUT_T:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
name: str, model: BaseModel, field_names: list
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic model with only a subset of model's fields."""
|
||||
fields = {
|
||||
field_name: (
|
||||
model.__fields__[field_name].type_,
|
||||
model.__fields__[field_name].default,
|
||||
)
|
||||
for field_name in field_names
|
||||
if field_name in model.__fields__
|
||||
}
|
||||
return create_model(name, **fields) # type: ignore
|
||||
|
||||
|
||||
def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
|
||||
"""Get the arguments from a function's signature."""
|
||||
schema = inferred_model.schema()["properties"]
|
||||
valid_keys = signature(func).parameters
|
||||
return {k: schema[k] for k in valid_keys}
|
||||
|
||||
|
||||
def _warn_args_kwargs(func: Callable) -> None:
|
||||
# Check if the function has *args or **kwargs.
|
||||
# Tools don't interact well with these.
|
||||
sig = signature(func)
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == Parameter.VAR_POSITIONAL:
|
||||
logger.warning(f"{func.__name__} uses *args, which are not well supported.")
|
||||
elif param.kind == Parameter.VAR_KEYWORD:
|
||||
logger.warning(
|
||||
f"{func.__name__} uses **kwargs, which are not well supported."
|
||||
)
|
||||
|
||||
|
||||
def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature."""
|
||||
_warn_args_kwargs(func)
|
||||
inferred_model = validate_arguments(func).model # type: ignore
|
||||
# Pydantic adds placeholder virtual fields we need to strip
|
||||
filtered_args = get_filtered_args(inferred_model, func)
|
||||
return _create_subset_model(
|
||||
f"{model_name}Schema", inferred_model, list(filtered_args)
|
||||
)
|
||||
|
||||
|
||||
STRUCTURED_RUN_T = TypeVar("STRUCTURED_RUN_T", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseStructuredTool(
|
||||
AbstractStructuredTool[OUTPUT_T, STRUCTURED_RUN_T],
|
||||
Generic[OUTPUT_T, STRUCTURED_RUN_T],
|
||||
):
|
||||
"""The structured tool that requires _run classes to take in a base model."""
|
||||
|
||||
args_schema: Type[STRUCTURED_RUN_T] # :meta private:
|
||||
|
||||
def _parse_input(self, tool_input: dict) -> STRUCTURED_RUN_T:
|
||||
return self.args_schema.parse_obj(tool_input)
|
||||
|
||||
@property
|
||||
def args(self) -> Dict:
|
||||
"""Return the JSON schema for the tool's args."""
|
||||
schema = self.args_schema.schema()
|
||||
result = {"properties": schema["properties"]}
|
||||
if "definitions" in schema:
|
||||
result["definitions"] = schema["definitions"]
|
||||
return result
|
||||
|
||||
|
||||
class StructuredTool(BaseStructuredTool[Any, BaseModel]):
|
||||
"""StructuredTool that takes in function or coroutine directly."""
|
||||
|
||||
func: Callable[..., Any]
|
||||
"""The function to run when the tool is called."""
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||
"""The asynchronous version of the function."""
|
||||
args_schema: Type[BaseModel] # :meta private:
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
"""The JSON Schema arguments for the tool."""
|
||||
return self.args_schema.schema()["properties"]
|
||||
|
||||
def _run(self, tool_input: BaseModel) -> Any:
|
||||
"""Use the tool."""
|
||||
parsed_dict = tool_input.dict()
|
||||
parsed_dict = {k: getattr(tool_input, k) for k in parsed_dict.keys()}
|
||||
return self.func(**parsed_dict)
|
||||
|
||||
async def _arun(self, tool_input: BaseModel) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
parsed_dict = tool_input.dict()
|
||||
parsed_dict = {k: getattr(tool_input, k) for k in parsed_dict.keys()}
|
||||
return await self.coroutine(**parsed_dict)
|
||||
raise NotImplementedError(f"StructuredTool {self.name} does not support async")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable[..., Any],
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> "StructuredTool":
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
Args:
|
||||
func: The function to run when the tool is called.
|
||||
coroutine: The asynchronous version of the function.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop.
|
||||
args_schema: optional argument schema for user to specify
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
name: The name of the tool. Defaults to the function name.
|
||||
description: The description of the tool. Defaults to the function
|
||||
docstring.
|
||||
"""
|
||||
description = func.__doc__ or description
|
||||
if description is None or not description.strip():
|
||||
raise ValueError(
|
||||
f"Function {func.__name__} must have a docstring, or set description."
|
||||
)
|
||||
name = name or func.__name__
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
_args_schema = create_schema_from_function(f"{name}Schema", func)
|
||||
description = f"{name}{signature(func)} - {description}"
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
return_direct=return_direct,
|
||||
args_schema=_args_schema,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
def structured_tool(
|
||||
*args: Union[str, Callable],
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to the tool.
|
||||
return_direct: Whether to return directly from the tool rather
|
||||
than continuing the agent loop.
|
||||
args_schema: Optional argument schema for user to specify. If
|
||||
none, will infer the schema from the function's signature.
|
||||
|
||||
Requires:
|
||||
- Function must be of type (str) -> str
|
||||
- Function must have a docstring
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
|
||||
@tool("search", return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
# Searches the API for the query.
|
||||
return
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(func: Callable) -> StructuredTool:
|
||||
return StructuredTool.from_function(
|
||||
name=tool_name,
|
||||
func=func,
|
||||
args_schema=args_schema,
|
||||
return_direct=return_direct,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
# if the argument is a string, then we use the string as the tool name
|
||||
# Example usage: @tool("search", return_direct=True)
|
||||
return _make_with_name(args[0])
|
||||
elif len(args) == 1 and callable(args[0]):
|
||||
# if the argument is a function, then we use the function name as the tool name
|
||||
# Example usage: @tool
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
elif len(args) == 0:
|
||||
# if there are no arguments, then we use the function name as the tool name
|
||||
# Example usage: @tool(return_direct=True)
|
||||
def _partial(func: Callable[[str], str]) -> BaseStructuredTool:
|
||||
return _make_with_name(func.__name__)(func)
|
||||
|
||||
return _partial
|
||||
else:
|
||||
raise ValueError("Too many arguments for tool decorator")
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Util that calls OpenWeatherMap using PyOWM."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.tools.base import BaseModel
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
"""Test tool utils."""
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import pydantic
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.tools.base import BaseTool, SchemaAnnotationError
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
@@ -23,169 +20,7 @@ def test_unnamed_decorator() -> None:
|
||||
assert search_api.name == "search_api"
|
||||
assert not search_api.return_direct
|
||||
assert search_api("test") == "API result"
|
||||
|
||||
|
||||
class _MockSchema(BaseModel):
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_structured_args() -> None:
|
||||
"""Test functionality with structured arguments."""
|
||||
structured_api = _MockStructuredTool()
|
||||
assert isinstance(structured_api, BaseTool)
|
||||
assert structured_api.name == "structured_api"
|
||||
expected_result = "1 True {'foo': 'bar'}"
|
||||
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
|
||||
assert structured_api.run(args) == expected_result
|
||||
|
||||
|
||||
def test_unannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool without type hints raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _UnAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_misannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _MisAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema: BaseModel = _MockSchema # type: ignore
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_forward_ref_annotated_base_tool_accepted() -> None:
|
||||
"""Test that a using forward ref annotation syntax is accepted.""" ""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: "Type[BaseModel]" = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_subclass_annotated_base_tool_accepted() -> None:
|
||||
"""Test BaseTool child w/ custom schema isn't overwritten."""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[_MockSchema] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
|
||||
tool = _ForwardRefAnnotatedTool()
|
||||
assert tool.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@tool(args_schema=_MockSchema)
|
||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func, Tool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
|
||||
@tool
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, Tool)
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
== structured_tool_input.args
|
||||
)
|
||||
|
||||
|
||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1}, {arg2}, {opt_arg}"
|
||||
|
||||
assert isinstance(structured_tool_input, Tool)
|
||||
assert structured_tool_input.name == "structured_tool_input"
|
||||
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
||||
expected_result = "1, 0.001, {'foo': 'bar'}"
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_single_str_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def unstructured_tool_input(tool_input: str) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{tool_input}"
|
||||
|
||||
assert isinstance(unstructured_tool_input, Tool)
|
||||
assert unstructured_tool_input.args_schema is None
|
||||
assert search_api.args_schema == str
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
@@ -202,8 +37,8 @@ def test_base_tool_inheritance_base_schema() -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
simple_tool = _MockSimpleTool()
|
||||
assert simple_tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
|
||||
assert simple_tool.args_schema == str
|
||||
expected_args = {"properties": {"tool_input": {"type": "string"}}}
|
||||
assert simple_tool.args == expected_args
|
||||
|
||||
|
||||
@@ -215,56 +50,11 @@ def test_tool_lambda_args_schema() -> None:
|
||||
description="A tool",
|
||||
func=lambda tool_input: tool_input,
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input"}}
|
||||
assert tool.args_schema == str
|
||||
expected_args = {"properties": {"tool_input": {"type": "string"}}}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {
|
||||
"tool_input": {"title": "Tool Input"},
|
||||
"other_arg": {"title": "Other Arg"},
|
||||
}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_partial_function_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a partial function."""
|
||||
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# We don't yet support args_schema inference for partial functions
|
||||
# so want to make sure we proactively raise an error
|
||||
Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
"""Test inferred schema of decorated fn with no args."""
|
||||
|
||||
@tool
|
||||
def empty_tool_input() -> str:
|
||||
"""Return a constant."""
|
||||
return "the empty result"
|
||||
|
||||
assert isinstance(empty_tool_input, Tool)
|
||||
assert empty_tool_input.name == "empty_tool_input"
|
||||
assert empty_tool_input.args == {}
|
||||
assert empty_tool_input.run({}) == "the empty result"
|
||||
|
||||
|
||||
def test_named_tool_decorator() -> None:
|
||||
"""Test functionality when arguments are provided as input to decorator."""
|
||||
|
||||
@@ -304,32 +94,17 @@ def test_unnamed_tool_decorator_return_direct() -> None:
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_tool_with_kwargs() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
def test_base_tool_decorator_multiple_args() -> None:
|
||||
"""Test the schema that's generated is still a simple string."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(
|
||||
arg_1: float,
|
||||
ping: str = "hi",
|
||||
) -> str:
|
||||
def some_tool(query: str, foo: int = 3, bar: Optional[dict] = None) -> str:
|
||||
"""Search the API for the query."""
|
||||
return f"arg_1={arg_1}, ping={ping}"
|
||||
return f"{query} {foo} {bar}"
|
||||
|
||||
assert isinstance(search_api, Tool)
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
"ping": "pong",
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=pong"
|
||||
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=hi"
|
||||
assert isinstance(some_tool, Tool)
|
||||
assert some_tool.name == "some_tool"
|
||||
assert some_tool.run("foo") == "foo 3 None"
|
||||
|
||||
|
||||
def test_missing_docstring() -> None:
|
||||
|
||||
188
tests/unit_tests/tools/test_structured.py
Normal file
188
tests/unit_tests/tools/test_structured.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.tools.structured import (
|
||||
BaseStructuredTool,
|
||||
StructuredTool,
|
||||
structured_tool,
|
||||
)
|
||||
|
||||
|
||||
class _MockSchema(BaseModel):
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseStructuredTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, tool_input: _MockSchema) -> str:
|
||||
return f"{tool_input.arg1} {tool_input.arg2} {tool_input.arg3}"
|
||||
|
||||
async def _arun(self, tool_input: _MockSchema) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_structured_args() -> None:
|
||||
"""Test functionality with structured arguments."""
|
||||
structured_api = _MockStructuredTool()
|
||||
assert isinstance(structured_api, BaseStructuredTool)
|
||||
assert structured_api.name == "structured_api"
|
||||
expected_result = "1 True {'foo': 'bar'}"
|
||||
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
|
||||
assert structured_api.run(args) == expected_result
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@structured_tool(args_schema=_MockSchema)
|
||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func, StructuredTool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
|
||||
@structured_tool
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, StructuredTool)
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
== structured_tool_input.args
|
||||
)
|
||||
|
||||
|
||||
def test_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = StructuredTool.from_function(
|
||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||
name="tool",
|
||||
description="A tool",
|
||||
)
|
||||
assert set(tool.args_schema.schema()["properties"]) == {"tool_input", "other_arg"}
|
||||
expected_args = {
|
||||
"tool_input": {"title": "Tool Input"},
|
||||
"other_arg": {"title": "Other Arg"},
|
||||
}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_partial_function_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a partial function."""
|
||||
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
with pytest.raises(pydantic.error_wrappers.ValidationError):
|
||||
# We don't yet support args_schema inference for partial functions
|
||||
# so want to make sure we proactively raise an error
|
||||
StructuredTool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
|
||||
|
||||
def test_tool_with_kwargs() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
|
||||
@structured_tool(return_direct=True)
|
||||
def search_api(
|
||||
arg_1: float,
|
||||
ping: str = "hi",
|
||||
) -> str:
|
||||
"""Search the API for the query."""
|
||||
return f"arg_1={arg_1}, ping={ping}"
|
||||
|
||||
assert isinstance(search_api, StructuredTool)
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
"ping": "pong",
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=pong"
|
||||
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_1": 3.2,
|
||||
}
|
||||
)
|
||||
assert result == "arg_1=3.2, ping=hi"
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
"""Test inferred schema of decorated fn with no args."""
|
||||
|
||||
@structured_tool
|
||||
def empty_tool_input() -> str:
|
||||
"""Return a constant."""
|
||||
return "the empty result"
|
||||
|
||||
assert isinstance(empty_tool_input, StructuredTool)
|
||||
assert empty_tool_input.name == "empty_tool_input"
|
||||
assert empty_tool_input.args == {}
|
||||
assert empty_tool_input.run({}) == "the empty result"
|
||||
|
||||
|
||||
def test_nested_pydantic_args() -> None:
|
||||
"""Test inferred schema when args are nested pydantic models."""
|
||||
# This is a pattern that is common with FastAPI methods.
|
||||
# If we only parse a dict input but pass the dict
|
||||
# to the function, we are limited only to primitive types
|
||||
# in general.
|
||||
|
||||
class SomeNestedInput(BaseModel):
|
||||
arg2: str
|
||||
|
||||
class SomeInput(BaseModel):
|
||||
arg1: int
|
||||
arg2: SomeNestedInput
|
||||
|
||||
@structured_tool
|
||||
def nested_tool(some_input: SomeInput) -> dict:
|
||||
"""Return a constant."""
|
||||
return some_input.dict()
|
||||
|
||||
assert isinstance(nested_tool, StructuredTool)
|
||||
assert nested_tool.name == "nested_tool"
|
||||
input_ = {"some_input": {"arg1": 1, "arg2": {"arg2": "foo"}}}
|
||||
assert nested_tool.run(input_) == input_["some_input"]
|
||||
|
||||
|
||||
def test_warning_on_args_kwargs(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Test inferred schema when args are nested pydantic models."""
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
|
||||
@structured_tool
|
||||
def anything_goes(*foo: Any, **bar: Any) -> str:
|
||||
"""Return a constant."""
|
||||
return str(foo) + "|" + str(bar)
|
||||
|
||||
# Check if the expected warning message was logged
|
||||
assert any(
|
||||
"anything_goes uses *args" in record.message for record in caplog.records
|
||||
)
|
||||
assert any(
|
||||
"anything_goes uses **kwargs" in record.message for record in caplog.records
|
||||
)
|
||||
Reference in New Issue
Block a user