mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
add return_direct flag to tool (#537)
adds a return_direct flag to tools, which just returns the tool output as the final output
This commit is contained in:
parent
1f248c47f3
commit
4974f49bb7
@ -44,7 +44,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 2,
|
||||
"id": "36ed392e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -63,7 +63,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 3,
|
||||
"id": "56ff7670",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -94,7 +94,6 @@
|
||||
"source": [
|
||||
"# Construct the agent. We will use the default agent type here.\n",
|
||||
"# See documentation for a full list of options.\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
|
||||
]
|
||||
},
|
||||
@ -248,10 +247,86 @@
|
||||
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a8d95999",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using tools to return directly\n",
|
||||
"\n",
|
||||
"Often, it can be desirable to have a tool output returned directly to the user, if it's called. You can do this easily with LangChain by setting the `return_direct` flag for a tool to be True."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c8f65640",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_math_chain = LLMMathChain(llm=llm)\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name=\"Calculator\",\n",
|
||||
" func=llm_math_chain.run,\n",
|
||||
" description=\"useful for when you need to answer questions about math\",\n",
|
||||
" return_direct=True\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "4bd1c4bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "52fe0594",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 2**.12\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'Answer: 1.2599210498948732'"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.run(\"whats 2**.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3450512e",
|
||||
"id": "7cc1b875",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@ -259,7 +334,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.0 64-bit ('llm-env')",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -273,7 +348,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.0"
|
||||
"version": "3.10.9"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -239,12 +239,20 @@ class AgentExecutor(Chain, BaseModel):
|
||||
else:
|
||||
return iterations < self.max_iterations
|
||||
|
||||
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
|
||||
if self.verbose:
|
||||
self.callback_manager.on_agent_finish(output, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Run text through and get agent response."""
|
||||
# Do any preparation necessary when receiving a new input.
|
||||
self.agent.prepare_for_new_call()
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
@ -258,24 +266,20 @@ class AgentExecutor(Chain, BaseModel):
|
||||
output = self.agent.plan(intermediate_steps, **inputs)
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.verbose:
|
||||
self.callback_manager.on_agent_finish(output, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
return self._return(output, intermediate_steps)
|
||||
|
||||
# And then we lookup the tool
|
||||
# Otherwise we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
chain = name_to_tool_map[output.tool]
|
||||
tool = name_to_tool_map[output.tool]
|
||||
if self.verbose:
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": str(chain)[:60] + "..."}, output, color="green"
|
||||
{"name": str(tool.func)[:60] + "..."}, output, color="green"
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = chain(output.tool_input)
|
||||
observation = tool.func(output.tool_input)
|
||||
color = color_mapping[output.tool]
|
||||
return_direct = tool.return_direct
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
self.callback_manager.on_tool_error(e)
|
||||
@ -287,21 +291,22 @@ class AgentExecutor(Chain, BaseModel):
|
||||
)
|
||||
observation = f"{output.tool} is not a valid tool, try another one."
|
||||
color = None
|
||||
return_direct = False
|
||||
if self.verbose:
|
||||
llm_prefix = "" if return_direct else self.agent.llm_prefix
|
||||
self.callback_manager.on_tool_end(
|
||||
observation,
|
||||
color=color,
|
||||
observation_prefix=self.agent.observation_prefix,
|
||||
llm_prefix=self.agent.llm_prefix,
|
||||
llm_prefix=llm_prefix,
|
||||
)
|
||||
intermediate_steps.append((output, observation))
|
||||
if return_direct:
|
||||
# Set the log to "" because we do not want to log it.
|
||||
output = AgentFinish({self.agent.return_values[0]: observation}, "")
|
||||
return self._return(output, intermediate_steps)
|
||||
iterations += 1
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
if self.verbose:
|
||||
self.callback_manager.on_agent_finish(output, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
return final_output
|
||||
return self._return(output, intermediate_steps)
|
||||
|
@ -10,3 +10,4 @@ class Tool:
|
||||
name: str
|
||||
func: Callable[[str], str]
|
||||
description: Optional[str] = None
|
||||
return_direct: bool = False
|
||||
|
@ -169,3 +169,24 @@ def test_agent_with_callbacks_not_verbose() -> None:
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_agent_tool_return_direct() -> None:
|
||||
"""Test agent using tools that return directly."""
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "misalignment"
|
||||
|
Loading…
Reference in New Issue
Block a user