diff --git a/docs/examples/demos/self_ask_with_search.ipynb b/docs/examples/demos/self_ask_with_search.ipynb index 6cbc7c97054..bf271cb3fe1 100644 --- a/docs/examples/demos/self_ask_with_search.ipynb +++ b/docs/examples/demos/self_ask_with_search.ipynb @@ -20,9 +20,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n", - "\n", - "\u001b[1m> Entering new chain...\u001b[0m\n", "\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", @@ -33,8 +30,6 @@ "\u001b[32;1m\u001b[1;3mFollow up: Where is Carlos Alcaraz from?\u001b[0m\n", "Intermediate answer: \u001b[36;1m\u001b[1;3mEl Palmar, Spain\u001b[0m\n", "\u001b[32;1m\u001b[1;3mSo the final answer is: El Palmar, Spain\u001b[0m\n", - "\u001b[1m> Finished chain.\u001b[0m\n", - "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, diff --git a/langchain/__init__.py b/langchain/__init__.py index 1591eb8f31b..d567d63eee6 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -8,10 +8,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f: from langchain.chains import ( LLMChain, LLMMathChain, - MRKLChain, PythonChain, - ReActChain, - SelfAskWithSearchChain, SerpAPIChain, SQLDatabaseChain, VectorDBQA, @@ -19,6 +16,7 @@ from langchain.chains import ( from langchain.docstore import Wikipedia from langchain.llms import Cohere, HuggingFaceHub, OpenAI from langchain.prompts import BasePromptTemplate, PromptTemplate +from langchain.smart_chains import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.sql_database import SQLDatabase from langchain.vectorstores import FAISS, ElasticVectorSearch diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index ae27d37ea3a..41654ed9bbe 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -1,10 +1,7 @@ """Chains are easily reusable components which can be linked together.""" from langchain.chains.llm import LLMChain from langchain.chains.llm_math.base import LLMMathChain -from langchain.chains.mrkl.base import MRKLChain from langchain.chains.python import PythonChain -from langchain.chains.react.base import ReActChain -from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain from langchain.chains.serpapi import SerpAPIChain from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.vector_db_qa.base import VectorDBQA @@ -13,10 +10,7 @@ __all__ = [ "LLMChain", "LLMMathChain", "PythonChain", - "SelfAskWithSearchChain", "SerpAPIChain", - "ReActChain", "SQLDatabaseChain", - "MRKLChain", "VectorDBQA", ] diff --git a/langchain/smart_chains/__init__.py b/langchain/smart_chains/__init__.py new file mode 100644 index 00000000000..1b049008ef0 --- /dev/null +++ b/langchain/smart_chains/__init__.py @@ -0,0 +1,6 @@ +"""Smart chains.""" +from langchain.smart_chains.mrkl.base import MRKLChain +from langchain.smart_chains.react.base import ReActChain +from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain + +__all__ = ["MRKLChain", "SelfAskWithSearchChain", "ReActChain"] diff --git a/langchain/chains/mrkl/__init__.py b/langchain/smart_chains/mrkl/__init__.py similarity index 100% rename from langchain/chains/mrkl/__init__.py rename to langchain/smart_chains/mrkl/__init__.py diff --git a/langchain/chains/mrkl/base.py b/langchain/smart_chains/mrkl/base.py similarity index 65% rename from langchain/chains/mrkl/base.py rename to langchain/smart_chains/mrkl/base.py index 3aa63b5066f..db357ae1d9c 100644 --- a/langchain/chains/mrkl/base.py +++ b/langchain/smart_chains/mrkl/base.py @@ -1,16 +1,12 @@ """Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf.""" -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Callable, List, NamedTuple, Optional, Tuple -from pydantic import BaseModel, Extra - -from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.mrkl.prompt import BASE_TEMPLATE -from langchain.chains.router import LLMRouterChain -from langchain.input import ChainedInput, get_color_mapping from langchain.llms.base import LLM -from langchain.prompts import BasePromptTemplate, PromptTemplate -from langchain.chains.router_expert import RouterExpertChain, ExpertConfig +from langchain.prompts import PromptTemplate +from langchain.smart_chains.mrkl.prompt import BASE_TEMPLATE +from langchain.smart_chains.router import LLMRouterChain +from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain FINAL_ANSWER_ACTION = "Final Answer: " @@ -79,33 +75,20 @@ class MRKLRouterChain(LLMRouterChain): return get_action_and_input(text) -class MRKLChain(Chain, BaseModel): +class MRKLChain(RouterExpertChain): """Chain that implements the MRKL system. Example: .. code-block:: python - from langchain import OpenAI, Prompt, MRKLChain + from langchain import OpenAI, MRKLChain from langchain.chains.mrkl.base import ChainConfig llm = OpenAI(temperature=0) prompt = PromptTemplate(...) - action_to_chain_map = {...} - mrkl = MRKLChain( - llm=llm, - prompt=prompt, - action_to_chain_map=action_to_chain_map - ) + chains = [...] + mrkl = MRKLChain.from_chains(llm=llm, prompt=prompt) """ - llm: LLM - """LLM wrapper to use as router.""" - chain_configs: List[ChainConfig] - """Chain configs this chain has access to.""" - action_to_chain_map: Dict[str, Callable] - """Mapping from action name to chain to execute.""" - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - @classmethod def from_chains( cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any @@ -145,47 +128,8 @@ class MRKLChain(Chain, BaseModel): ] mrkl = MRKLChain.from_chains(llm, chains) """ - action_to_chain_map = {chain.action_name: chain.action for chain in chains} - return cls( - llm=llm, - chain_configs=chains, - action_to_chain_map=action_to_chain_map, - **kwargs, - ) - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. - - :meta private: - """ - return [self.output_key] - - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - router_chain = MRKLRouterChain(self.llm, self.chain_configs) - question = inputs[self.input_key] + router_chain = MRKLRouterChain(llm, chains) expert_configs = [ - ExpertConfig(expert_name=c.action_name, expert=c.action) - for c in self.chain_configs + ExpertConfig(expert_name=c.action_name, expert=c.action) for c in chains ] - chain = RouterExpertChain( - router_chain=router_chain, - expert_configs=expert_configs, - verbose=self.verbose - ) - output = chain.run(question) - return {self.output_key: output} + return cls(router_chain=router_chain, expert_configs=expert_configs, **kwargs) diff --git a/langchain/chains/mrkl/prompt.py b/langchain/smart_chains/mrkl/prompt.py similarity index 100% rename from langchain/chains/mrkl/prompt.py rename to langchain/smart_chains/mrkl/prompt.py diff --git a/langchain/chains/react/__init__.py b/langchain/smart_chains/react/__init__.py similarity index 100% rename from langchain/chains/react/__init__.py rename to langchain/smart_chains/react/__init__.py diff --git a/langchain/chains/react/base.py b/langchain/smart_chains/react/base.py similarity index 61% rename from langchain/chains/react/base.py rename to langchain/smart_chains/react/base.py index fd11fcd4533..44db72015e5 100644 --- a/langchain/chains/react/base.py +++ b/langchain/smart_chains/react/base.py @@ -1,18 +1,16 @@ """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" import re -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional, Tuple -from pydantic import BaseModel, Extra +from pydantic import BaseModel -from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.react.prompt import PROMPT -from langchain.chains.router import LLMRouterChain from langchain.docstore.base import Docstore from langchain.docstore.document import Document -from langchain.input import ChainedInput from langchain.llms.base import LLM -from langchain.chains.router_expert import RouterExpertChain, ExpertConfig +from langchain.smart_chains.react.prompt import PROMPT +from langchain.smart_chains.router import LLMRouterChain +from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain class ReActRouterChain(LLMRouterChain, BaseModel): @@ -46,7 +44,7 @@ class ReActRouterChain(LLMRouterChain, BaseModel): @property def finish_action_name(self) -> str: - """The action name of when to finish the chain.""" + """Name of the action of when to finish the chain.""" return "Finish" @property @@ -61,12 +59,15 @@ class ReActRouterChain(LLMRouterChain, BaseModel): class DocstoreExplorer: + """Class to assist with exploration of a document store.""" def __init__(self, docstore: Docstore): - self.docstore=docstore - self.document = None + """Initialize with a docstore, and set initial document to None.""" + self.docstore = docstore + self.document: Optional[Document] = None - def search(self, term: str): + def search(self, term: str) -> str: + """Search for a term in the docstore, and if found save.""" result = self.docstore.search(term) if isinstance(result, Document): self.document = result @@ -75,13 +76,14 @@ class DocstoreExplorer: self.document = None return result - def lookup(self, term: str): + def lookup(self, term: str) -> str: + """Lookup a term in document (if saved).""" if self.document is None: raise ValueError("Cannot lookup without a successful search first") return self.document.lookup(term) -class ReActChain(Chain, BaseModel): +class ReActChain(RouterExpertChain): """Chain that implements the ReAct paper. Example: @@ -91,47 +93,14 @@ class ReActChain(Chain, BaseModel): react = ReAct(llm=OpenAI()) """ - llm: LLM - """LLM wrapper to use.""" - docstore: Docstore - """Docstore to use.""" - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. - - :meta private: - """ - return [self.output_key] - - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: - question = inputs[self.input_key] - router_chain = ReActRouterChain(self.llm) - docstore_explorer = DocstoreExplorer(self.docstore) + def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any): + """Initialize with the LLM and a docstore.""" + router_chain = ReActRouterChain(llm) + docstore_explorer = DocstoreExplorer(docstore) expert_configs = [ ExpertConfig(expert_name="Search", expert=docstore_explorer.search), - ExpertConfig(expert_name="Lookup", expert=docstore_explorer.lookup) + ExpertConfig(expert_name="Lookup", expert=docstore_explorer.lookup), ] - chain = RouterExpertChain( - router_chain=router_chain, - expert_configs=expert_configs, - verbose=self.verbose + super().__init__( + router_chain=router_chain, expert_configs=expert_configs, **kwargs ) - output = chain.run(question) - return {self.output_key: output} diff --git a/langchain/chains/react/prompt.py b/langchain/smart_chains/react/prompt.py similarity index 100% rename from langchain/chains/react/prompt.py rename to langchain/smart_chains/react/prompt.py diff --git a/langchain/chains/router.py b/langchain/smart_chains/router.py similarity index 97% rename from langchain/chains/router.py rename to langchain/smart_chains/router.py index c8adc5a8b88..1593607171b 100644 --- a/langchain/chains/router.py +++ b/langchain/smart_chains/router.py @@ -48,7 +48,7 @@ class RouterChain(Chain, BaseModel, ABC): @property def finish_action_name(self) -> str: - """The action name of when to finish the chain.""" + """Name of the action of when to finish the chain.""" return "Final Answer" def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: diff --git a/langchain/chains/router_expert.py b/langchain/smart_chains/router_expert.py similarity index 80% rename from langchain/chains/router_expert.py rename to langchain/smart_chains/router_expert.py index e268cb9731a..408e9ad8cc4 100644 --- a/langchain/chains/router_expert.py +++ b/langchain/smart_chains/router_expert.py @@ -1,21 +1,15 @@ """Router-Expert framework.""" -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple +from typing import Callable, Dict, List, NamedTuple from pydantic import BaseModel, Extra from langchain.chains.base import Chain -from langchain.chains.llm import LLMChain -from langchain.chains.mrkl.prompt import BASE_TEMPLATE -from langchain.chains.router import LLMRouterChain from langchain.input import ChainedInput, get_color_mapping -from langchain.llms.base import LLM -from langchain.prompts import BasePromptTemplate, PromptTemplate -from langchain.chains.router import RouterChain - - +from langchain.smart_chains.router import RouterChain class ExpertConfig(NamedTuple): + """Configuration for experts.""" expert_name: str expert: Callable[[str], str] @@ -57,8 +51,14 @@ class RouterExpertChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: action_to_chain_map = {e.expert_name: e.expert for e in self.expert_configs} + starter_string = ( + inputs[self.input_key] + + self.starter_string + + self.router_chain.router_prefix + ) chained_input = ChainedInput( - f"{inputs[self.input_key]}{self.starter_string}{self.router_chain.router_prefix}", verbose=self.verbose + starter_string, + verbose=self.verbose, ) color_mapping = get_color_mapping( [c.expert_name for c in self.expert_configs], excluded_colors=["green"] @@ -74,4 +74,4 @@ class RouterExpertChain(Chain, BaseModel): ca = chain(action_input) chained_input.add(f"\n{self.router_chain.observation_prefix}") chained_input.add(ca, color=color_mapping[action]) - chained_input.add(f"\n{self.router_chain.router_prefix}") \ No newline at end of file + chained_input.add(f"\n{self.router_chain.router_prefix}") diff --git a/langchain/chains/self_ask_with_search/__init__.py b/langchain/smart_chains/self_ask_with_search/__init__.py similarity index 100% rename from langchain/chains/self_ask_with_search/__init__.py rename to langchain/smart_chains/self_ask_with_search/__init__.py diff --git a/langchain/chains/self_ask_with_search/base.py b/langchain/smart_chains/self_ask_with_search/base.py similarity index 53% rename from langchain/chains/self_ask_with_search/base.py rename to langchain/smart_chains/self_ask_with_search/base.py index ed6efee05f8..fa501b783e6 100644 --- a/langchain/chains/self_ask_with_search/base.py +++ b/langchain/smart_chains/self_ask_with_search/base.py @@ -1,16 +1,12 @@ """Chain that does self ask with search.""" -from typing import Any, Dict, List, Tuple +from typing import Any, Tuple -from pydantic import BaseModel, Extra - -from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.router import LLMRouterChain -from langchain.chains.self_ask_with_search.prompt import PROMPT from langchain.chains.serpapi import SerpAPIChain -from langchain.input import ChainedInput from langchain.llms.base import LLM -from langchain.chains.router_expert import RouterExpertChain, ExpertConfig +from langchain.smart_chains.router import LLMRouterChain +from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain +from langchain.smart_chains.self_ask_with_search.prompt import PROMPT class SelfAskWithSearchRouter(LLMRouterChain): @@ -32,7 +28,7 @@ class SelfAskWithSearchRouter(LLMRouterChain): finish_string = "So the final answer is: " if finish_string not in last_line: raise ValueError("We should probably never get here") - return "Final Answer", text[len(finish_string):] + return "Final Answer", text[len(finish_string) :] if ":" not in last_line: after_colon = last_line @@ -57,7 +53,7 @@ class SelfAskWithSearchRouter(LLMRouterChain): return "" -class SelfAskWithSearchChain(Chain, BaseModel): +class SelfAskWithSearchChain(RouterExpertChain): """Chain that does self ask with search. Example: @@ -68,39 +64,16 @@ class SelfAskWithSearchChain(Chain, BaseModel): self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain) """ - llm: LLM - """LLM wrapper to use.""" - search_chain: SerpAPIChain - """Search chain to use.""" - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. - - :meta private: - """ - return [self.output_key] - - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any): + """Initialize with just an LLM and a search chain.""" intermediate = "\nIntermediate answer:" - router = SelfAskWithSearchRouter(self.llm, stops=[intermediate]) - expert_configs = [ExpertConfig(expert_name="Intermediate Answer", expert=self.search_chain.run)] - chain = RouterExpertChain(router_chain=router, expert_configs=expert_configs, verbose=self.verbose, starter_string="\nAre follow up questions needed here:") - output = chain.run(inputs[self.input_key]) - return {self.output_key: output} + router = SelfAskWithSearchRouter(llm, stops=[intermediate]) + expert_configs = [ + ExpertConfig(expert_name="Intermediate Answer", expert=search_chain.run) + ] + super().__init__( + router_chain=router, + expert_configs=expert_configs, + starter_string="\nAre follow up questions needed here:", + **kwargs + ) diff --git a/langchain/chains/self_ask_with_search/prompt.py b/langchain/smart_chains/self_ask_with_search/prompt.py similarity index 100% rename from langchain/chains/self_ask_with_search/prompt.py rename to langchain/smart_chains/self_ask_with_search/prompt.py diff --git a/tests/integration_tests/chains/test_react.py b/tests/integration_tests/chains/test_react.py index 500b5cea9ff..3414701efe4 100644 --- a/tests/integration_tests/chains/test_react.py +++ b/tests/integration_tests/chains/test_react.py @@ -1,8 +1,8 @@ """Integration test for self ask with search.""" -from langchain.chains.react.base import ReActChain from langchain.docstore.wikipedia import Wikipedia from langchain.llms.openai import OpenAI +from langchain.smart_chains.react.base import ReActChain def test_react() -> None: diff --git a/tests/integration_tests/chains/test_self_ask_with_search.py b/tests/integration_tests/chains/test_self_ask_with_search.py index 7ec49a8b0cd..afc58ab8381 100644 --- a/tests/integration_tests/chains/test_self_ask_with_search.py +++ b/tests/integration_tests/chains/test_self_ask_with_search.py @@ -1,7 +1,7 @@ """Integration test for self ask with search.""" -from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain from langchain.chains.serpapi import SerpAPIChain from langchain.llms.openai import OpenAI +from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain def test_self_ask_with_search() -> None: diff --git a/tests/unit_tests/smart_chains/__init__.py b/tests/unit_tests/smart_chains/__init__.py new file mode 100644 index 00000000000..4b80a982738 --- /dev/null +++ b/tests/unit_tests/smart_chains/__init__.py @@ -0,0 +1 @@ +"""Test smart chain functionality.""" diff --git a/tests/unit_tests/chains/test_mrkl.py b/tests/unit_tests/smart_chains/test_mrkl.py similarity index 95% rename from tests/unit_tests/chains/test_mrkl.py rename to tests/unit_tests/smart_chains/test_mrkl.py index 5764351357f..fb95a04ae03 100644 --- a/tests/unit_tests/chains/test_mrkl.py +++ b/tests/unit_tests/smart_chains/test_mrkl.py @@ -2,13 +2,13 @@ import pytest -from langchain.chains.mrkl.base import ( +from langchain.prompts import PromptTemplate +from langchain.smart_chains.mrkl.base import ( ChainConfig, MRKLRouterChain, get_action_and_input, ) -from langchain.chains.mrkl.prompt import BASE_TEMPLATE -from langchain.prompts import PromptTemplate +from langchain.smart_chains.mrkl.prompt import BASE_TEMPLATE from tests.unit_tests.llms.fake_llm import FakeLLM diff --git a/tests/unit_tests/chains/test_react.py b/tests/unit_tests/smart_chains/test_react.py similarity index 97% rename from tests/unit_tests/chains/test_react.py rename to tests/unit_tests/smart_chains/test_react.py index c9a8ad14375..31d47fb1dbd 100644 --- a/tests/unit_tests/chains/test_react.py +++ b/tests/unit_tests/smart_chains/test_react.py @@ -4,11 +4,11 @@ from typing import Any, List, Mapping, Optional, Union import pytest -from langchain.chains.react.base import ReActChain, ReActRouterChain from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM from langchain.prompts.prompt import PromptTemplate +from langchain.smart_chains.react.base import ReActChain, ReActRouterChain _PAGE_CONTENT = """This is a page about LangChain.