establish router
@@ -2,7 +2,11 @@
 
 import pytest
 
-from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
+from langchain.chains.mrkl.base import (
+    ChainConfig,
+    MRKLRouterChain,
+    get_action_and_input,
+)
 from langchain.chains.mrkl.prompt import BASE_TEMPLATE
 from langchain.prompts import Prompt
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -59,12 +63,12 @@ def test_from_chains() -> None:
             action_name="bar", action=lambda x: "bar", action_description="foobar2"
         ),
     ]
-    mrkl_chain = MRKLChain.from_chains(FakeLLM(), chain_configs)
+    router_chain = MRKLRouterChain(FakeLLM(), chain_configs)
     expected_tools_prompt = "foo: foobar1\nbar: foobar2"
     expected_tool_names = "foo, bar"
     expected_template = BASE_TEMPLATE.format(
         tools=expected_tools_prompt, tool_names=expected_tool_names
     )
-    prompt = mrkl_chain.prompt
+    prompt = router_chain.llm_chain.prompt
     assert isinstance(prompt, Prompt)
     assert prompt.template == expected_template
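For context on the API this commit establishes: the updated test shows that MRKLRouterChain is constructed directly from an LLM and a list of ChainConfig tools, and that the rendered MRKL prompt is exposed through its inner llm_chain. A minimal sketch of that usage, mirroring the test above (FakeLLM and the toy lambdas are placeholders; a real setup would pass a production LLM and real tool callables):

from langchain.chains.mrkl.base import ChainConfig, MRKLRouterChain
from tests.unit_tests.llms.fake_llm import FakeLLM

# Each ChainConfig describes one tool the router can dispatch to:
# a name, a callable, and a description that gets interpolated into the prompt.
chain_configs = [
    ChainConfig(
        action_name="foo", action=lambda x: "foo", action_description="foobar1"
    ),
    ChainConfig(
        action_name="bar", action=lambda x: "bar", action_description="foobar2"
    ),
]

# The router builds an inner LLMChain whose prompt is BASE_TEMPLATE with the
# tool descriptions ("foo: foobar1\nbar: foobar2") and names ("foo, bar") filled in.
router_chain = MRKLRouterChain(FakeLLM(), chain_configs)
print(router_chain.llm_chain.prompt.template)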
@@ -4,8 +4,7 @@ from typing import Any, List, Mapping, Optional, Union
 
 import pytest
 
-from langchain.chains.llm import LLMChain
-from langchain.chains.react.base import ReActChain, predict_until_observation
+from langchain.chains.react.base import ReActChain, ReActRouterChain
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
 from langchain.llms.base import LLM
@@ -53,8 +52,8 @@ def test_predict_until_observation_normal() -> None:
     """Test predict_until_observation when observation is made normally."""
     outputs = ["foo\nAction 1: search[foo]"]
     fake_llm = FakeListLLM(outputs)
-    fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
-    ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
+    router_chain = ReActRouterChain(llm=fake_llm)
+    action, directive, ret_text = router_chain.get_action_and_input("")
     assert ret_text == outputs[0]
     assert action == "search"
     assert directive == "foo"
@@ -64,22 +63,13 @@ def test_predict_until_observation_repeat() -> None:
     """Test when no action is generated initially."""
     outputs = ["foo", " search[foo]"]
     fake_llm = FakeListLLM(outputs)
-    fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
-    ret_text, action, directive = predict_until_observation(fake_llm_chain, "", 1)
+    router_chain = ReActRouterChain(llm=fake_llm)
+    action, directive, ret_text = router_chain.get_action_and_input("")
     assert ret_text == "foo\nAction 1: search[foo]"
     assert action == "search"
     assert directive == "foo"
 
 
-def test_predict_until_observation_error() -> None:
-    """Test handling of generation of text that cannot be parsed."""
-    outputs = ["foo\nAction 1: foo"]
-    fake_llm = FakeListLLM(outputs)
-    fake_llm_chain = LLMChain(llm=fake_llm, prompt=_FAKE_PROMPT)
-    with pytest.raises(ValueError):
-        predict_until_observation(fake_llm_chain, "", 1)
-
-
 def test_react_chain() -> None:
     """Test react chain."""
     responses = [
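The ReAct tests follow the same pattern: instead of building an LLMChain and calling predict_until_observation, they wrap the LLM in a ReActRouterChain and ask it for the next action. A test-style sketch of that call pattern, assuming it runs inside this test module (FakeListLLM is the stub LLM defined at the top of the file, which replays canned responses in order):

from langchain.chains.react.base import ReActRouterChain

# One canned generation: a thought followed by an "Action 1: search[foo]" line.
outputs = ["foo\nAction 1: search[foo]"]
fake_llm = FakeListLLM(outputs)

# The router parses the generation into a tool name and its input, and also
# returns the raw text it produced along the way.
router_chain = ReActRouterChain(llm=fake_llm)
action, directive, ret_text = router_chain.get_action_and_input("")
assert (action, directive) == ("search", "foo")
assert ret_text == outputs[0]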