establish router

This commit is contained in:
Harrison Chase
2022-11-19 17:01:52 -08:00
parent 8869b0ab0e
commit 6f55fa8ba7
9 changed files with 228 additions and 186 deletions

View File

@@ -2,7 +2,11 @@
import pytest
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
from langchain.chains.mrkl.base import (
ChainConfig,
MRKLRouterChain,
get_action_and_input,
)
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.prompts import Prompt
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -59,12 +63,12 @@ def test_from_chains() -> None:
action_name="bar", action=lambda x: "bar", action_description="foobar2"
),
]
mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs)
router_chain = MRKLRouterChain(FakeLLM(), chain_configs)
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
expected_tool_names = "foo, bar"
expected_template = BASE_TEMPLATE.format(
tools=expected_tools_prompt, tool_names=expected_tool_names
)
prompt = mrkl_chain.prompt
prompt = router_chain.llm_chain.prompt
assert isinstance(prompt, Prompt)
assert prompt.template == expected_template

View File

@@ -4,8 +4,7 @@ from typing import Any, List, Mapping, Optional, Union
import pytest
from langchain.chains.llm import LLMChain
from langchain.chains.react.base import ReActChain, predict_until_observation
from langchain.chains.react.base import ReActChain, ReActRouterChain
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
@@ -53,8 +52,8 @@ def test_predict_until_observation_normal() -> None:
"""Test predict_until_observation when observation is made normally."""
outputs = ["foo\nAction 1: search[foo]"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
router_chain = ReActRouterChain(llm=fake_llm)
action, directive, ret_text = router_chain.get_action_and_input("")
assert ret_text == outputs[0]
assert action == "search"
assert directive == "foo"
@@ -64,22 +63,13 @@ def test_predict_until_observation_repeat() -> None:
"""Test when no action is generated initially."""
outputs = ["foo", " search[foo]"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
router_chain = ReActRouterChain(llm=fake_llm)
action, directive, ret_text = router_chain.get_action_and_input("")
assert ret_text == "foo\nAction 1: search[foo]"
assert action == "search"
assert directive == "foo"
def test_predict_until_observation_error() -> None:
"""Test handling of generation of text that cannot be parsed."""
outputs = ["foo\nAction 1: foo"]
fake_llm = FakeListLLM(outputs)
fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
with pytest.raises(ValueError):
predict_until_observation(fake_llm_chain, "", 1)
def test_react_chain() -> None:
"""Test react chain."""
responses = [