Gracefully degrade when model asks for nonexistent tool (#268)

Not yet tested, but very simple change, assumption is that we're cool
with just producing a generic output when tool is not found
This commit is contained in:
John McDonnell 2022-12-06 21:52:48 -08:00 committed by GitHub
parent 2180a91196
commit 68666d6a22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 9 deletions

View File

@ -150,12 +150,17 @@ class Agent(Chain, BaseModel, ABC):
if output.tool == self.finish_tool_name: if output.tool == self.finish_tool_name:
return {self.output_key: output.tool_input} return {self.output_key: output.tool_input}
# Otherwise we lookup the tool # Otherwise we lookup the tool
chain = name_to_tool_map[output.tool] if output.tool in name_to_tool_map:
# We then call the tool on the tool input to get an observation chain = name_to_tool_map[output.tool]
observation = chain(output.tool_input) # We then call the tool on the tool input to get an observation
observation = chain(output.tool_input)
color = color_mapping[output.tool]
else:
observation = f"{output.tool} is not a valid tool, try another one."
color = None
# We then log the observation # We then log the observation
chained_input.add(f"\n{self.observation_prefix}") chained_input.add(f"\n{self.observation_prefix}")
chained_input.add(observation, color=color_mapping[output.tool]) chained_input.add(observation, color=color)
# We then add the LLM prefix into the prompt to get the LLM to start # We then add the LLM prefix into the prompt to get the LLM to start
# thinking, and start the loop all over. # thinking, and start the loop all over.
chained_input.add(f"\n{self.llm_prefix}") chained_input.add(f"\n{self.llm_prefix}")

View File

@ -0,0 +1,45 @@
"""Unit tests for agents."""
from typing import Any, List, Mapping, Optional
from langchain.agents import Tool, initialize_agent
from langchain.llms.base import LLM
class FakeListLLM(LLM):
"""Fake LLM for testing that outputs elements of a list."""
def __init__(self, responses: List[str]):
"""Initialize with list of responses."""
self.responses = responses
self.i = -1
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Increment counter, and then return response in that index."""
self.i += 1
print(self.i)
print(self.responses)
return self.responses[self.i]
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
bad_action_name = "BadAction"
responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
]
fake_llm = FakeListLLM(responses)
tools = [
Tool("Search", lambda x: x, "Useful for searching"),
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
]
agent = initialize_agent(
tools, fake_llm, agent="zero-shot-react-description", verbose=True
)
output = agent.run("when was langchain made")
assert output == "curses foiled again"

View File

@ -2,8 +2,6 @@
from typing import Any, List, Mapping, Optional, Union from typing import Any, List, Mapping, Optional, Union
import pytest
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
@ -94,10 +92,12 @@ def test_react_chain() -> None:
def test_react_chain_bad_action() -> None: def test_react_chain_bad_action() -> None:
"""Test react chain when bad action given.""" """Test react chain when bad action given."""
bad_action_name = "BadAction"
responses = [ responses = [
"I should probably search\nAction 1: BadAction[langchain]", f"I'm turning evil\nAction 1: {bad_action_name}[langchain]",
"Oh well\nAction 2: Finish[curses foiled again]",
] ]
fake_llm = FakeListLLM(responses) fake_llm = FakeListLLM(responses)
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore()) react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
with pytest.raises(KeyError): output = react_chain.run("when was langchain made")
react_chain.run("when was langchain made") assert output == "curses foiled again"