From 4974f49bb7c2761c6f7b9f229869e6fea0abf7b3 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Fri, 6 Jan 2023 06:40:32 -0800
Subject: [PATCH] add return_direct flag to tool (#537)

adds a return_direct flag to tools; when a tool with this flag set is
called, its output is returned directly as the agent's final output
---
 .../agents/examples/custom_tools.ipynb | 87 +++++++++++++++++--
 langchain/agents/agent.py              | 41 +++++----
 langchain/agents/tools.py              |  1 +
 tests/unit_tests/agents/test_agent.py  | 21 +++++
 4 files changed, 126 insertions(+), 24 deletions(-)

diff --git a/docs/modules/agents/examples/custom_tools.ipynb b/docs/modules/agents/examples/custom_tools.ipynb
index 568f9ab451b..061ae06ad8c 100644
--- a/docs/modules/agents/examples/custom_tools.ipynb
+++ b/docs/modules/agents/examples/custom_tools.ipynb
@@ -44,7 +44,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 2,
    "id": "36ed392e",
    "metadata": {},
    "outputs": [],
@@ -63,7 +63,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "id": "56ff7670",
    "metadata": {},
    "outputs": [],
@@ -94,7 +94,6 @@
    "source": [
     "# Construct the agent. We will use the default agent type here.\n",
     "# See documentation for a full list of options.\n",
-    "llm = OpenAI(temperature=0)\n",
     "agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
    ]
   },
@@ -248,10 +247,86 @@
     "agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "a8d95999",
+   "metadata": {},
+   "source": [
+    "## Using tools to return directly\n",
+    "\n",
+    "Often, it can be desirable to have a tool's output returned directly to the user when it is called. You can do this easily with LangChain by setting a tool's `return_direct` flag to True."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "c8f65640",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm_math_chain = LLMMathChain(llm=llm)\n",
+    "tools = [\n",
+    "    Tool(\n",
+    "        name=\"Calculator\",\n",
+    "        func=llm_math_chain.run,\n",
+    "        description=\"useful for when you need to answer questions about math\",\n",
+    "        return_direct=True\n",
+    "    )\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4bd1c4bb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = OpenAI(temperature=0)\n",
+    "agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "52fe0594",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m I need to calculate this\n",
+      "Action: Calculator\n",
+      "Action Input: 2**.12\u001b[0m\n",
+      "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'Answer: 1.2599210498948732'"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "agent.run(\"whats 2**.12\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "3450512e",
+   "id": "7cc1b875",
    "metadata": {},
    "outputs": [],
    "source": []
@@ -259,7 +334,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.9.0 64-bit ('llm-env')",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -273,7 +348,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.0"
+   "version": "3.10.9"
   },
   "vscode": {
    "interpreter": {
diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index dc669176e44..f79bc6cf246 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -239,12 +239,20 @@ class AgentExecutor(Chain, BaseModel):
         else:
             return iterations < self.max_iterations
 
+    def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
+        if self.verbose:
+            self.callback_manager.on_agent_finish(output, color="green")
+        final_output = output.return_values
+        if self.return_intermediate_steps:
+            final_output["intermediate_steps"] = intermediate_steps
+        return final_output
+
     def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         """Run text through and get agent response."""
         # Do any preparation necessary when receiving a new input.
         self.agent.prepare_for_new_call()
         # Construct a mapping of tool name to tool for easy lookup
-        name_to_tool_map = {tool.name: tool.func for tool in self.tools}
+        name_to_tool_map = {tool.name: tool for tool in self.tools}
         # We construct a mapping from each tool to a color, used for logging.
         color_mapping = get_color_mapping(
             [tool.name for tool in self.tools], excluded_colors=["green"]
@@ -258,24 +266,20 @@ class AgentExecutor(Chain, BaseModel):
             output = self.agent.plan(intermediate_steps, **inputs)
             # If the tool chosen is the finishing tool, then we end and return.
             if isinstance(output, AgentFinish):
-                if self.verbose:
-                    self.callback_manager.on_agent_finish(output, color="green")
-                final_output = output.return_values
-                if self.return_intermediate_steps:
-                    final_output["intermediate_steps"] = intermediate_steps
-                return final_output
+                return self._return(output, intermediate_steps)
 
-            # And then we lookup the tool
+            # Otherwise we lookup the tool
             if output.tool in name_to_tool_map:
-                chain = name_to_tool_map[output.tool]
+                tool = name_to_tool_map[output.tool]
                 if self.verbose:
                     self.callback_manager.on_tool_start(
-                        {"name": str(chain)[:60] + "..."}, output, color="green"
+                        {"name": str(tool.func)[:60] + "..."}, output, color="green"
                     )
                 try:
                     # We then call the tool on the tool input to get an observation
-                    observation = chain(output.tool_input)
+                    observation = tool.func(output.tool_input)
                     color = color_mapping[output.tool]
+                    return_direct = tool.return_direct
                 except Exception as e:
                     if self.verbose:
                         self.callback_manager.on_tool_error(e)
@@ -287,21 +291,22 @@
                     )
                 observation = f"{output.tool} is not a valid tool, try another one."
                 color = None
+                return_direct = False
             if self.verbose:
+                llm_prefix = "" if return_direct else self.agent.llm_prefix
                 self.callback_manager.on_tool_end(
                     observation,
                     color=color,
                     observation_prefix=self.agent.observation_prefix,
-                    llm_prefix=self.agent.llm_prefix,
+                    llm_prefix=llm_prefix,
                 )
             intermediate_steps.append((output, observation))
+            if return_direct:
+                # Set the log to "" because we do not want to log it.
+                output = AgentFinish({self.agent.return_values[0]: observation}, "")
+                return self._return(output, intermediate_steps)
             iterations += 1
         output = self.agent.return_stopped_response(
             self.early_stopping_method, intermediate_steps, **inputs
         )
-        if self.verbose:
-            self.callback_manager.on_agent_finish(output, color="green")
-        final_output = output.return_values
-        if self.return_intermediate_steps:
-            final_output["intermediate_steps"] = intermediate_steps
-        return final_output
+        return self._return(output, intermediate_steps)
diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py
index 5235bc3bb9b..cdcd06504ac 100644
--- a/langchain/agents/tools.py
+++ b/langchain/agents/tools.py
@@ -10,3 +10,4 @@ class Tool:
     name: str
     func: Callable[[str], str]
     description: Optional[str] = None
+    return_direct: bool = False
diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py
index 581f61ca52a..c9461e5df14 100644
--- a/tests/unit_tests/agents/test_agent.py
+++ b/tests/unit_tests/agents/test_agent.py
@@ -169,3 +169,24 @@ def test_agent_with_callbacks_not_verbose() -> None:
     assert handler.starts == 0
     assert handler.ends == 0
     assert handler.errors == 0
+
+
+def test_agent_tool_return_direct() -> None:
+    """Test agent using tools that return directly."""
+    tool = "Search"
+    responses = [
+        f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
+        "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
+    ]
+    fake_llm = FakeListLLM(responses=responses)
+    tools = [
+        Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
+    ]
+    agent = initialize_agent(
+        tools,
+        fake_llm,
+        agent="zero-shot-react-description",
+    )
+
+    output = agent.run("when was langchain made")
+    assert output == "misalignment"
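
Usage sketch (not part of the patch): the snippet below restates the notebook cells added above as a plain script, for readability outside the notebook JSON. It is a sketch only; it assumes an OpenAI API key is configured and uses the langchain import layout from around the time of this patch.

# Illustrative only: mirrors the custom_tools.ipynb cells added in this patch.
from langchain.agents import Tool, initialize_agent
from langchain.chains import LLMMathChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
llm_math_chain = LLMMathChain(llm=llm)

tools = [
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
        return_direct=True,  # the tool's observation becomes the final answer
    )
]

agent = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True)

# With return_direct=True, AgentExecutor wraps the first Calculator observation
# in an AgentFinish and returns it as-is, skipping the final LLM call that would
# otherwise rephrase the observation.
print(agent.run("whats 2**.12"))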