mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 18:38:48 +00:00
(WIP) agents (#171)
This commit is contained in:
1
tests/unit_tests/agents/__init__.py
Normal file
1
tests/unit_tests/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test agent functionality."""
|
@@ -2,8 +2,9 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent, get_action_and_input
|
||||
from langchain.agents.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.prompts import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
@@ -29,7 +30,7 @@ def test_get_final_answer() -> None:
|
||||
"Final Answer: 1994"
|
||||
)
|
||||
action, action_input = get_action_and_input(llm_output)
|
||||
assert action == "Final Answer: "
|
||||
assert action == "Final Answer"
|
||||
assert action_input == "1994"
|
||||
|
||||
|
||||
@@ -52,19 +53,15 @@ def test_bad_action_line() -> None:
|
||||
def test_from_chains() -> None:
|
||||
"""Test initializing from chains."""
|
||||
chain_configs = [
|
||||
ChainConfig(
|
||||
action_name="foo", action=lambda x: "foo", action_description="foobar1"
|
||||
),
|
||||
ChainConfig(
|
||||
action_name="bar", action=lambda x: "bar", action_description="foobar2"
|
||||
),
|
||||
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
|
||||
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
|
||||
]
|
||||
mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs)
|
||||
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
|
||||
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
|
||||
expected_tool_names = "foo, bar"
|
||||
expected_template = BASE_TEMPLATE.format(
|
||||
tools=expected_tools_prompt, tool_names=expected_tool_names
|
||||
)
|
||||
prompt = mrkl_chain.prompt
|
||||
prompt = agent.llm_chain.prompt
|
||||
assert isinstance(prompt, PromptTemplate)
|
||||
assert prompt.template == expected_template
|
@@ -4,8 +4,8 @@ from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.react.base import ReActChain, predict_until_observation
|
||||
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
@@ -51,33 +51,32 @@ class FakeDocstore(Docstore):
|
||||
|
||||
def test_predict_until_observation_normal() -> None:
|
||||
"""Test predict_until_observation when observation is made normally."""
|
||||
outputs = ["foo\nAction 1: search[foo]"]
|
||||
outputs = ["foo\nAction 1: Search[foo]"]
|
||||
fake_llm = FakeListLLM(outputs)
|
||||
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
|
||||
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
|
||||
assert ret_text == outputs[0]
|
||||
assert action == "search"
|
||||
assert directive == "foo"
|
||||
tools = [
|
||||
Tool("Search", lambda x: x),
|
||||
Tool("Lookup", lambda x: x),
|
||||
]
|
||||
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
|
||||
output = agent.get_action("")
|
||||
assert output.log == outputs[0]
|
||||
assert output.tool == "Search"
|
||||
assert output.tool_input == "foo"
|
||||
|
||||
|
||||
def test_predict_until_observation_repeat() -> None:
|
||||
"""Test when no action is generated initially."""
|
||||
outputs = ["foo", " search[foo]"]
|
||||
outputs = ["foo", " Search[foo]"]
|
||||
fake_llm = FakeListLLM(outputs)
|
||||
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
|
||||
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
|
||||
assert ret_text == "foo\nAction 1: search[foo]"
|
||||
assert action == "search"
|
||||
assert directive == "foo"
|
||||
|
||||
|
||||
def test_predict_until_observation_error() -> None:
|
||||
"""Test handling of generation of text that cannot be parsed."""
|
||||
outputs = ["foo\nAction 1: foo"]
|
||||
fake_llm = FakeListLLM(outputs)
|
||||
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
|
||||
with pytest.raises(ValueError):
|
||||
predict_until_observation(fake_llm_chain, "", 1)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x),
|
||||
Tool("Lookup", lambda x: x),
|
||||
]
|
||||
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
|
||||
output = agent.get_action("")
|
||||
assert output.log == "foo\nAction 1: Search[foo]"
|
||||
assert output.tool == "Search"
|
||||
assert output.tool_input == "foo"
|
||||
|
||||
|
||||
def test_react_chain() -> None:
|
||||
@@ -89,9 +88,8 @@ def test_react_chain() -> None:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses)
|
||||
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
|
||||
inputs = {"question": "when was langchain made"}
|
||||
output = react_chain(inputs)
|
||||
assert output["answer"] == "2022"
|
||||
output = react_chain.run("when was langchain made")
|
||||
assert output == "2022"
|
||||
|
||||
|
||||
def test_react_chain_bad_action() -> None:
|
||||
@@ -101,5 +99,5 @@ def test_react_chain_bad_action() -> None:
|
||||
]
|
||||
fake_llm = FakeListLLM(responses)
|
||||
react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore())
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(KeyError):
|
||||
react_chain.run("when was langchain made")
|
Reference in New Issue
Block a user