mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
Improving the text of the invalid tool to list the available tools. (#8767)
Description: When using a ReAct Agent with tools and no tool is found, the InvalidTool gets called. Previously it just asked for a different action, but I've found that if you list the available actions it improves the chances of getting a valid action in the next round. I've added a UnitTest for it also. @hinthornw
This commit is contained in:
parent
d9bc46186d
commit
2111ed3c75
@ -897,7 +897,10 @@ s
|
|||||||
else:
|
else:
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||||
observation = InvalidTool().run(
|
observation = InvalidTool().run(
|
||||||
agent_action.tool,
|
{
|
||||||
|
"requested_tool_name": agent_action.tool,
|
||||||
|
"available_tool_names": list(name_to_tool_map.keys()),
|
||||||
|
},
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
color=None,
|
color=None,
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
@ -992,7 +995,10 @@ s
|
|||||||
else:
|
else:
|
||||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||||
observation = await InvalidTool().arun(
|
observation = await InvalidTool().arun(
|
||||||
agent_action.tool,
|
{
|
||||||
|
"requested_tool_name": agent_action.tool,
|
||||||
|
"available_tool_names": list(name_to_tool_map.keys()),
|
||||||
|
},
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
color=None,
|
color=None,
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Interface for tools."""
|
"""Interface for tools."""
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
@ -12,23 +12,33 @@ class InvalidTool(BaseTool):
|
|||||||
"""Tool that is run when invalid tool name is encountered by agent."""
|
"""Tool that is run when invalid tool name is encountered by agent."""
|
||||||
|
|
||||||
name = "invalid_tool"
|
name = "invalid_tool"
|
||||||
"""Name of the tool."""
|
description = "Called when tool name is invalid. Suggests valid tool names."
|
||||||
description = "Called when tool name is invalid."
|
|
||||||
"""Description of the tool."""
|
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
self,
|
||||||
|
requested_tool_name: str,
|
||||||
|
available_tool_names: List[str],
|
||||||
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
return f"{tool_name} is not a valid tool, try another one."
|
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
|
||||||
|
return (
|
||||||
|
f"{requested_tool_name} is not a valid tool, "
|
||||||
|
f"try one of [{available_tool_names_str}]."
|
||||||
|
)
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
requested_tool_name: str,
|
||||||
|
available_tool_names: List[str],
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
return f"{tool_name} is not a valid tool, try another one."
|
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
|
||||||
|
return (
|
||||||
|
f"{requested_tool_name} is not a valid tool, "
|
||||||
|
f"try one of [{available_tool_names_str}]."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["InvalidTool", "BaseTool", "tool", "Tool"]
|
__all__ = ["InvalidTool", "BaseTool", "tool", "Tool"]
|
||||||
|
@ -257,3 +257,26 @@ def test_agent_lookup_tool() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert agent.lookup_tool("Search") == tools[0]
|
assert agent.lookup_tool("Search") == tools[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_invalid_tool() -> None:
|
||||||
|
"""Test agent invalid tool and correct suggestions."""
|
||||||
|
fake_llm = FakeListLLM(responses=["FooBarBaz\nAction: Foo\nAction Input: Bar"])
|
||||||
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="Search",
|
||||||
|
func=lambda x: x,
|
||||||
|
description="Useful for searching",
|
||||||
|
return_direct=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools=tools,
|
||||||
|
llm=fake_llm,
|
||||||
|
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
|
return_intermediate_steps=True,
|
||||||
|
max_iterations=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = agent("when was langchain made")
|
||||||
|
resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."
|
||||||
|
Loading…
Reference in New Issue
Block a user