(WIP) agents (#171)

This commit is contained in:
Harrison Chase
2022-11-22 06:16:26 -08:00
committed by GitHub
parent 22bd12a097
commit d3a7429f61
46 changed files with 1111 additions and 733 deletions

View File

@@ -0,0 +1 @@
"""Test agent functionality."""

View File

@@ -2,8 +2,9 @@
import pytest
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.agents.mrkl.base import ZeroShotAgent, get_action_and_input
from langchain.agents.mrkl.prompt import BASE_TEMPLATE
from langchain.agents.tools import Tool
from langchain.prompts import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -29,7 +30,7 @@ def test_get_final_answer() -> None:
"Final Answer: 1994"
)
action, action_input = get_action_and_input(llm_output)
assert action == "Final Answer: "
assert action == "Final Answer"
assert action_input == "1994"
@@ -52,19 +53,15 @@ def test_bad_action_line() -> None:
def test_from_chains() -> None:
"""Test initializing from chains."""
chain_configs = [
ChainConfig(
action_name="foo", action=lambda x: "foo", action_description="foobar1"
),
ChainConfig(
action_name="bar", action=lambda x: "bar", action_description="foobar2"
),
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
]
mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs)
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
expected_tool_names = "foo, bar"
expected_template = BASE_TEMPLATE.format(
tools=expected_tools_prompt, tool_names=expected_tool_names
)
prompt = mrkl_chain.prompt
prompt = agent.llm_chain.prompt
assert isinstance(prompt, PromptTemplate)
assert prompt.template == expected_template

View File

@@ -4,8 +4,8 @@ from typing import Any, List, Mapping, Optional, Union
import pytest
from langchain.chains.llm import LLMChain
from langchain.chains.react.base import ReActChain, predict_until_observation
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
from langchain.agents.tools import Tool
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
@@ -51,33 +51,32 @@ class FakeDocstore(Docstore):
def test_predict_until_observation_normal() -> None:
"""Test predict_until_observation when observation is made normally."""
outputs = ["foo\nAction 1: search[foo]"]
outputs = ["foo\nAction 1: Search[foo]"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
assert ret_text == outputs[0]
assert action == "search"
assert directive == "foo"
tools = [
Tool("Search", lambda x: x),
Tool("Lookup", lambda x: x),
]
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
output = agent.get_action("")
assert output.log == outputs[0]
assert output.tool == "Search"
assert output.tool_input == "foo"
def test_predict_until_observation_repeat() -> None:
"""Test when no action is generated initially."""
outputs = ["foo", " search[foo]"]
outputs = ["foo", " Search[foo]"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
assert ret_text == "foo\nAction 1: search[foo]"
assert action == "search"
assert directive == "foo"
def test_predict_until_observation_error() -> None:
"""Test handling of generation of text that cannot be parsed."""
outputs = ["foo\nAction 1: foo"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
with pytest.raises(ValueError):
predict_until_observation(fake_llm_chain, "", 1)
tools = [
Tool("Search", lambda x: x),
Tool("Lookup", lambda x: x),
]
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
output = agent.get_action("")
assert output.log == "foo\nAction 1: Search[foo]"
assert output.tool == "Search"
assert output.tool_input == "foo"
def test_react_chain() -> None:
@@ -89,9 +88,8 @@ def test_react_chain() -> None:
]
fake_llm = FakeListLLM(responses)
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
inputs = {"question": "when was langchain made"}
output = react_chain(inputs)
assert output["answer"] == "2022"
output = react_chain.run("when was langchain made")
assert output == "2022"
def test_react_chain_bad_action() -> None:
@@ -101,5 +99,5 @@ def test_react_chain_bad_action() -> None:
]
fake_llm = FakeListLLM(responses)
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
with pytest.raises(ValueError):
with pytest.raises(KeyError):
react_chain.run("when was langchain made")