From 314a098fb6d4cc190027095532b9656f68fd0692 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 21 Nov 2022 12:03:29 -0800 Subject: [PATCH] cr --- tests/unit_tests/routing_chains/test_mrkl.py | 2 +- tests/unit_tests/routing_chains/test_react.py | 25 +++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/unit_tests/routing_chains/test_mrkl.py b/tests/unit_tests/routing_chains/test_mrkl.py index 66781088438..74ad7c85ee6 100644 --- a/tests/unit_tests/routing_chains/test_mrkl.py +++ b/tests/unit_tests/routing_chains/test_mrkl.py @@ -30,7 +30,7 @@ def test_get_final_answer() -> None: "Final Answer: 1994" ) action, action_input = get_action_and_input(llm_output) - assert action == "Final Answer: " + assert action == "Final Answer" assert action_input == "1994" diff --git a/tests/unit_tests/routing_chains/test_react.py b/tests/unit_tests/routing_chains/test_react.py index 9e82650040c..6d9b35b2e94 100644 --- a/tests/unit_tests/routing_chains/test_react.py +++ b/tests/unit_tests/routing_chains/test_react.py @@ -9,6 +9,7 @@ from langchain.docstore.document import Document from langchain.llms.base import LLM from langchain.prompts.prompt import PromptTemplate from langchain.routing_chains.react.base import ReActChain, ReActDocstoreRouter +from langchain.routing_chains.tools import Tool _PAGE_CONTENT = """This is a page about LangChain. @@ -50,23 +51,31 @@ class FakeDocstore(Docstore): def test_predict_until_observation_normal() -> None: """Test predict_until_observation when observation is made normally.""" - outputs = ["foo\nAction 1: search[foo]"] + outputs = ["foo\nAction 1: Search[foo]"] fake_llm = FakeListLLM(outputs) - router_chain = ReActDocstoreRouter(llm=fake_llm) + tools = [ + Tool("Search", lambda x: x), + Tool("Lookup", lambda x: x), + ] + router_chain = ReActDocstoreRouter.from_llm_and_tools(fake_llm, tools) output = router_chain.route("") assert output.log == outputs[0] - assert output.tool == "search" + assert output.tool == "Search" assert output.tool_input == "foo" def test_predict_until_observation_repeat() -> None: """Test when no action is generated initially.""" - outputs = ["foo", " search[foo]"] + outputs = ["foo", " Search[foo]"] fake_llm = FakeListLLM(outputs) - router_chain = ReActDocstoreRouter(llm=fake_llm) + tools = [ + Tool("Search", lambda x: x), + Tool("Lookup", lambda x: x), + ] + router_chain = ReActDocstoreRouter.from_llm_and_tools(fake_llm, tools) output = router_chain.route("") - assert output.log == "foo\nAction 1: search[foo]" - assert output.tool == "search" + assert output.log == "foo\nAction 1: Search[foo]" + assert output.tool == "Search" assert output.tool_input == "foo" @@ -91,5 +100,5 @@ def test_react_chain_bad_action() -> None: ] fake_llm = FakeListLLM(responses) react_chain = ReActChain(llm=fake_llm, docstore=FakeDocstore()) - with pytest.raises(ValueError): + with pytest.raises(KeyError): react_chain.run("when was langchain made")