diff --git a/docs/modules/agents/tools/custom_tools.ipynb b/docs/modules/agents/tools/custom_tools.ipynb index f69e2224b96..4f57597e07a 100644 --- a/docs/modules/agents/tools/custom_tools.ipynb +++ b/docs/modules/agents/tools/custom_tools.ipynb @@ -9,28 +9,29 @@ "\n", "When constructing your own agent, you will need to provide it with a list of Tools that it can use. Besides the actual function that is called, the Tool consists of several components:\n", "\n", - "- name (str), is required\n", - "- description (str), is optional\n", + "- name (str), is required and must be unique within a set of tools provided to an agent\n", + "- description (str), is optional but recommended, as it is used by an agent to determine tool use\n", "- return_direct (bool), defaults to False\n", "\n", - "The function that should be called when the tool is selected should take as input a single string and return a single string.\n", + "The function that should be called when the tool is selected should return a single string.\n", "\n", "There are two ways to define a tool, we will cover both in the example below." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "1aaba18c", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Import things that are needed generically\n", - "from langchain.agents import initialize_agent, Tool\n", - "from langchain.agents import AgentType\n", - "from langchain.tools import BaseTool\n", - "from langchain.llms import OpenAI\n", - "from langchain import LLMMathChain, SerpAPIWrapper" + "from langchain import LLMMathChain, SerpAPIWrapper\n", + "from langchain.agents import AgentType, Tool, initialize_agent, tool\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.tools import BaseTool" ] }, { @@ -43,12 +44,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "36ed392e", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "llm = OpenAI(temperature=0)" + "llm = ChatOpenAI(temperature=0)" ] }, { @@ -74,7 +77,9 @@ "cell_type": "code", "execution_count": 3, "id": "56ff7670", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Load the tool configs that are needed.\n", @@ -98,7 +103,9 @@ "cell_type": "code", "execution_count": 4, "id": "5b93047d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Construct the agent. We will use the default agent type here.\n", @@ -110,7 +117,9 @@ "cell_type": "code", "execution_count": 5, "id": "6f96a891", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -119,29 +128,24 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", + "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", "Action: Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI need to find out Gigi Hadid's age\n", + "Action: Search\n", + "Action Input: \"Gigi Hadid age\"\u001b[0m\u001b[36;1m\u001b[1;3m27 years\u001b[0m\u001b[32;1m\u001b[1;3mI need to calculate her age raised to the 0.43 power\n", "Action: Calculator\n", - "Action Input: 22^0.43\u001b[0m\n", + "Action Input: 27^(0.43)\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "22^0.43\u001b[32;1m\u001b[1;3m\n", - "```python\n", - "import math\n", - "print(math.pow(22, 0.43))\n", + "27^(0.43)\u001b[32;1m\u001b[1;3m```text\n", + "27**(0.43)\n", "```\n", + "...numexpr.evaluate(\"27**(0.43)\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n", - "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m4.125593352125936\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 4.125593352125936\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", + "Final Answer: 4.125593352125936\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -149,7 +153,7 @@ { "data": { "text/plain": [ - "\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\"" + "'4.125593352125936'" ] }, "execution_count": 5, @@ -171,9 +175,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "c58a7c40", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "class CustomSearchTool(BaseTool):\n", @@ -203,9 +209,11 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "3318a46f", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "tools = [CustomSearchTool(), CustomCalculatorTool()]" @@ -213,9 +221,11 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "ee2d0f3a", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)" @@ -223,9 +233,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "6a2cebbf", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -234,29 +246,24 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", + "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age\n", "Action: Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate her age raised to the 0.43 power\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mI now know Leo DiCaprio's girlfriend's name and that he's currently linked to Gigi Hadid. I need to find out Camila Morrone's age.\n", + "Action: Search\n", + "Action Input: \"Camila Morrone age\"\u001b[0m\u001b[36;1m\u001b[1;3m25 years\u001b[0m\u001b[32;1m\u001b[1;3mI have Camila Morrone's age. I need to calculate her age raised to the 0.43 power.\n", "Action: Calculator\n", - "Action Input: 22^0.43\u001b[0m\n", + "Action Input: 25^(0.43)\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "22^0.43\u001b[32;1m\u001b[1;3m\n", - "```python\n", - "import math\n", - "print(math.pow(22, 0.43))\n", + "25^(0.43)\u001b[32;1m\u001b[1;3m```text\n", + "25**(0.43)\n", "```\n", + "...numexpr.evaluate(\"25**(0.43)\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n", - "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m3.991298452658078\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the answer to the original question.\n", + "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -264,10 +271,10 @@ { "data": { "text/plain": [ - "\"Camila Morrone's age raised to the 0.43 power is 3.777824273683966.\"" + "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\"" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -288,9 +295,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 10, "id": "8f15307d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from langchain.agents import tool\n", @@ -298,22 +307,24 @@ "@tool\n", "def search_api(query: str) -> str:\n", " \"\"\"Searches the API for the query.\"\"\"\n", - " return \"Results\"" + " return f\"Results for query {query}\"" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 11, "id": "0a23b91b", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { "text/plain": [ - "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)" + "Tool(name='search_api', description='search_api(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=False, verbose=False, callback_manager=, func=, coroutine=None)" ] }, - "execution_count": 5, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -332,9 +343,11 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "id": "28cdf04d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "@tool(\"search\", return_direct=True)\n", @@ -345,17 +358,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "id": "1085a4bd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" + "Tool(name='search', description='search(query: str) -> str - Searches the API for the query.', args_schema=, return_direct=True, verbose=False, callback_manager=, func=, coroutine=None)" ] }, - "execution_count": 7, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -376,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "id": "79213f40", "metadata": {}, "outputs": [], @@ -386,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, "id": "e1067dcb", "metadata": {}, "outputs": [], @@ -396,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 16, "id": "6c66ffe8", "metadata": {}, "outputs": [], @@ -406,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 17, "id": "f45b5bc3", "metadata": {}, "outputs": [], @@ -416,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 18, "id": "565e2b9b", "metadata": {}, "outputs": [ @@ -427,21 +440,12 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n", + "\u001b[32;1m\u001b[1;3mI need to find out Leo DiCaprio's girlfriend's name and her age.\n", "Action: Google Search\n", - "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mCamila Morrone\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Camila Morrone's age\n", - "Action: Google Search\n", - "Action Input: \"Camila Morrone age\"\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3m25 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.43 power\n", + "Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\u001b[36;1m\u001b[1;3mI draw the lime at going to get a Mohawk, though.\" DiCaprio broke up with girlfriend Camila Morrone, 25, in the summer of 2022, after dating for four years. He's since been linked to another famous supermodel – Gigi Hadid.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to find out Camila Morrone's current age.\n", "Action: Calculator\n", - "Action Input: 25^0.43\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\u001b[0m\n", + "Action Input: 25^0.43\u001b[0m\u001b[33;1m\u001b[1;3mAnswer: 3.991298452658078\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n", + "Final Answer: Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -449,10 +453,10 @@ { "data": { "text/plain": [ - "\"Camila Morrone is Leo DiCaprio's girlfriend and her current age raised to the 0.43 power is 3.991298452658078.\"" + "\"Camila Morrone's current age raised to the 0.43 power is approximately 3.99.\"" ] }, - "execution_count": 12, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -478,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 19, "id": "3450512e", "metadata": {}, "outputs": [], @@ -507,7 +511,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 20, "id": "4b9a7849", "metadata": {}, "outputs": [ @@ -520,9 +524,7 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3m I should use a music search engine to find the answer\n", "Action: Music Search\n", - "Action Input: most famous song of christmas\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Action Input: most famous song of christmas\u001b[0m\u001b[33;1m\u001b[1;3m'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: 'All I Want For Christmas Is You' by Mariah Carey.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -534,7 +536,7 @@ "\"'All I Want For Christmas Is You' by Mariah Carey.\"" ] }, - "execution_count": 14, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -554,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 21, "id": "3bb6185f", "metadata": {}, "outputs": [], @@ -572,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "id": "113ddb84", "metadata": {}, "outputs": [], @@ -583,9 +585,11 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "id": "582439a6", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -596,9 +600,7 @@ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3m I need to calculate this\n", "Action: Calculator\n", - "Action Input: 2**.12\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Action Input: 2**.12\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 1.086734862526058\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -606,10 +608,10 @@ { "data": { "text/plain": [ - "'Answer: 1.2599210498948732'" + "'Answer: 1.086734862526058'" ] }, - "execution_count": 17, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -618,10 +620,149 @@ "agent.run(\"whats 2**.12\")" ] }, + { + "cell_type": "markdown", + "id": "8aa3c353-bd89-467c-9c27-b83a90cd4daa", + "metadata": {}, + "source": [ + "## Multi-argument tools\n", + "\n", + "Many functions expect structured inputs. These can also be supported using the Tool decorator or by directly subclassing `BaseTool`! We have to modify the LLM's OutputParser to map its string output to a dictionary to pass to the action, however." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "537bc628", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from typing import Optional, Union\n", + "\n", + "@tool\n", + "def custom_search(k: int, query: str, other_arg: Optional[str] = None):\n", + " \"\"\"The custom search function.\"\"\"\n", + " return f\"Here are the results for the custom search: k={k}, query={query}, other_arg={other_arg}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "d5c992cf-776a-40cd-a6c4-e7cf65ea709e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import re\n", + "from langchain.schema import (\n", + " AgentAction,\n", + " AgentFinish,\n", + ")\n", + "from langchain.agents import AgentOutputParser\n", + "\n", + "# We will add a custom parser to map the arguments to a dictionary\n", + "class CustomOutputParser(AgentOutputParser):\n", + " \n", + " def parse_tool_input(self, action_input: str) -> dict:\n", + " # Regex pattern to match arguments and their values\n", + " pattern = r\"(\\w+)\\s*=\\s*(None|\\\"[^\\\"]*\\\"|\\d+)\"\n", + " matches = re.findall(pattern, action_input)\n", + " \n", + " if not matches:\n", + " raise ValueError(f\"Could not parse action input: `{action_input}`\")\n", + "\n", + " # Create a dictionary with the parsed arguments and their values\n", + " parsed_input = {}\n", + " for arg, value in matches:\n", + " if value == \"None\":\n", + " parsed_value = None\n", + " elif value.isdigit():\n", + " parsed_value = int(value)\n", + " else:\n", + " parsed_value = value.strip('\"')\n", + " parsed_input[arg] = parsed_value\n", + "\n", + " return parsed_input\n", + " \n", + " def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n", + " # Check if agent should finish\n", + " if \"Final Answer:\" in llm_output:\n", + " return AgentFinish(\n", + " # Return values is generally always a dictionary with a single `output` key\n", + " # It is not recommended to try anything else at the moment :)\n", + " return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n", + " log=llm_output,\n", + " )\n", + " # Parse out the action and action input\n", + " regex = r\"Action\\s*\\d*\\s*:(.*?)\\nAction\\s*\\d*\\s*Input\\s*\\d*\\s*:[\\s]*(.*)\"\n", + " match = re.search(regex, llm_output, re.DOTALL)\n", + " if not match:\n", + " raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n", + " action = match.group(1).strip()\n", + " action_input = match.group(2)\n", + " tool_input = self.parse_tool_input(action_input)\n", + " # Return the action and action \n", + " return AgentAction(tool=action, tool_input=tool_input, log=llm_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "68269547-1482-4138-a6ea-58f00b4a9548", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm = OpenAI(temperature=0)\n", + "agent = initialize_agent([custom_search], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, agent_kwargs={\"output_parser\": CustomOutputParser()})" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "0947835a-691c-4f51-b8f4-6744e0e48ab1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to use a search function to find the answer\n", + "Action: custom_search\n", + "Action Input: k=1, query=\"me\"\u001b[0m\u001b[36;1m\u001b[1;3mHere are the results for the custom search: k=1, query=me, other_arg=None\u001b[0m\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: The results of the custom search for k=1, query=me, other_arg=None.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The results of the custom search for k=1, query=me, other_arg=None.'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"Search for me and tell me whatever it says\")" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "537bc628", + "id": "caf39c66-102b-42c1-baf2-777a49886ce4", "metadata": {}, "outputs": [], "source": [] @@ -643,7 +784,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.2" }, "vscode": { "interpreter": { diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 5d7421bf9e9..9ae8f698edc 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,6 +1,8 @@ """Interface for tools.""" from inspect import signature -from typing import Any, Awaitable, Callable, Optional, Union +from typing import Any, Awaitable, Callable, Optional, Type, Union + +from pydantic import BaseModel from langchain.tools.base import BaseTool, create_args_schema_model_from_signature @@ -14,6 +16,15 @@ class Tool(BaseTool): coroutine: Optional[Callable[..., Awaitable[str]]] = None """The asynchronous version of the function.""" + @property + def args(self) -> Type[BaseModel]: + """Generate an input pydantic model.""" + if self.args_schema is not None: + return self.args_schema + # Infer the schema directly from the function to add more structured + # arguments. + return create_args_schema_model_from_signature(self.func) + def _run(self, *args: Any, **kwargs: Any) -> str: """Use the tool.""" return self.func(*args, **kwargs) diff --git a/langchain/callbacks/comet_ml_callback.py b/langchain/callbacks/comet_ml_callback.py index 6f061f14aa4..f057cea84f0 100644 --- a/langchain/callbacks/comet_ml_callback.py +++ b/langchain/callbacks/comet_ml_callback.py @@ -371,7 +371,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.starts += 1 tool = action.tool - tool_input = action.tool_input + tool_input = str(action.tool_input) log = action.log resp = self._init_resp() diff --git a/langchain/schema.py b/langchain/schema.py index 8678b1a89f7..a2b709f138f 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar +from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar, Union from pydantic import BaseModel, Extra, Field, root_validator @@ -31,7 +31,7 @@ class AgentAction(NamedTuple): """Agent's action to take.""" tool: str - tool_input: str + tool_input: Union[str, dict] log: str diff --git a/tests/unit_tests/agents/test_mrkl.py b/tests/unit_tests/agents/test_mrkl.py index d88fbf02866..8cf4f6f8d05 100644 --- a/tests/unit_tests/agents/test_mrkl.py +++ b/tests/unit_tests/agents/test_mrkl.py @@ -16,7 +16,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM def get_action_and_input(text: str) -> Tuple[str, str]: output = MRKLOutputParser().parse(text) if isinstance(output, AgentAction): - return output.tool, output.tool_input + return output.tool, str(output.tool_input) else: return "Final Answer", output.return_values["output"] diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 5f3ddb63146..09dda518fdc 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -120,17 +120,26 @@ def test_unnamed_tool_decorator_return_direct() -> None: assert search_api.return_direct -def test_tool_with_kwarg() -> None: +def test_tool_with_kwargs() -> None: """Test functionality when only return direct is provided.""" @tool(return_direct=True) - def search_api(*args: Any, **kwargs: Any) -> str: + def search_api( + arg_1: float, *args: Any, ping: Optional[str] = None, **kwargs: Any + ) -> str: """Search the API for the query.""" - return f"foo={args}, kwargs={kwargs}" + return f"arg_1={arg_1}, foo={args}, ping={ping}, kwargs={kwargs}" assert isinstance(search_api, Tool) - result = search_api.run(tool_input={"args": "fam", "kwargs": {"bar": "baz"}}) - assert result == "foo=('fam',), kwargs={'bar': 'baz'}" + result = search_api.run( + tool_input={ + "arg_1": 3.2, + "args": "fam", + "kwargs": {"bar": "baz"}, + "ping": "pong", + } + ) + assert result == "arg_1=3.2, foo=('fam',), ping=pong, kwargs={'bar': 'baz'}" def test_missing_docstring() -> None: