diff --git a/docs/examples/demos/mrkl.ipynb b/docs/examples/demos/mrkl.ipynb index e9f364cbe3e..354a0066ee6 100644 --- a/docs/examples/demos/mrkl.ipynb +++ b/docs/examples/demos/mrkl.ipynb @@ -27,7 +27,7 @@ "outputs": [], "source": [ "from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain, SQLDatabase, SQLDatabaseChain\n", - "from langchain.chains.mrkl.base import ChainConfig" + "from langchain.routing_chains.mrkl.base import ChainConfig" ] }, { @@ -167,7 +167,7 @@ "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "What albums by Alanis Morissette are in the FooBar database?\n", - "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morissette'\u001b[0m\n", + "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = \"Alanis Morissette\"\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Jagged Little Pill\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", diff --git a/docs/examples/demos/react.ipynb b/docs/examples/demos/react.ipynb index f7a49d4a2c3..7b91f485256 100644 --- a/docs/examples/demos/react.ipynb +++ b/docs/examples/demos/react.ipynb @@ -41,10 +41,12 @@ "with.\n", "Action 1: Search[David Chanoff]\u001b[0m\n", "Observation 1: \u001b[36;1m\u001b[1;3mDavid Chanoff is a noted author of non-fiction work. His work has typically involved collaborations with the principal protagonist of the work concerned. His collaborators have included; Augustus A. White, Joycelyn Elders, Đoàn Văn Toại, William J. Crowe, Ariel Sharon, Kenneth Good and Felix Zandman. He has also written about a wide range of subjects including literary history, education and foreign for The Washington Post, The New Republic and The New York Times Magazine. He has published more than twelve books.\u001b[0m\n", - "Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe.\n", + "Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe. I\n", + "need to search him next.\n", "Action 2: Search[William J. Crowe]\u001b[0m\n", "Observation 2: \u001b[36;1m\u001b[1;3mWilliam James Crowe Jr. (January 2, 1925 – October 18, 2007) was a United States Navy admiral and diplomat who served as the 11th chairman of the Joint Chiefs of Staff under Presidents Ronald Reagan and George H. W. Bush, and as the ambassador to the United Kingdom and Chair of the Intelligence Oversight Board under President Bill Clinton.\u001b[0m\n", - "Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under President Bill Clinton.\n", + "Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under\n", + "President Bill Clinton. So the answer is Bill Clinton.\n", "Action 3: Finish[Bill Clinton]\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -68,7 +70,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3cb9d77c", + "id": "4ff64e81", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/examples/demos/routing_chains.ipynb b/docs/examples/demos/routing_chains.ipynb index 74915eced45..00b2df143f7 100644 --- a/docs/examples/demos/routing_chains.ipynb +++ b/docs/examples/demos/routing_chains.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "markdown", - "id": "4b21ae68", + "id": "3c6226b9", "metadata": {}, "source": [ "## Concepts\n", @@ -26,20 +26,19 @@ }, { "cell_type": "markdown", - "id": "c9677868", + "id": "05d4b21e", "metadata": {}, "source": [ "## Tools\n", - "When constructing your own Routing Chain, you will need to provide it with a list of tools that it can use. This is done with a list of ToolConfigs. The ToolConfig is used not only to create the Routing Chain, but is also sometimes used to create the router itself (often, the router logic depends on the tools available). \n", + "When constructing your own Routing Chain, you will need to provide it with a list of tools that it can use. This is done with a list of Tools. The Tools are used not only to create the Routing Chain, but is also sometimes used to create the router itself (often, the router logic depends on the tools available). \n", "\n", "```python\n", - "class ToolConfig(NamedTuple):\n", - " \"\"\"Configuration for tools.\"\"\"\n", + "class Tool(NamedTuple):\n", + " \"\"\"Interface for tools.\"\"\"\n", "\n", - " tool_name: str\n", - " tool: Callable[[str], str]\n", - " # Needed to construct some routers.\n", - " tool_description: Optional[str] = None\n", + " name: str\n", + " func: Callable[[str], str]\n", + " description: Optional[str] = None\n", "```\n", "\n", "The two required components of a ToolConfig are the name and then the tool itself. A tool description is optional, as it is needed for some routers but not all." @@ -50,22 +49,123 @@ "id": "2558a02d", "metadata": {}, "source": [ - "## Loading the chains\n", - "\n", - "```python\n", - "from langchain.routing_chains import load_chain\n", - "\n", - "tools: List[ToolConfig] = [...]\n", - "llm: LLM = OpenAI(temperature=0)\n", - "router_type: str = \"zero-shot\"\n", - "chain = load_chain(tools, llm, router_type, verbose=True)\n", - "```" + "## Loading the chains\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "36ed392e", + "metadata": {}, + "outputs": [], + "source": [ + "# Import things that are needed generically\n", + "from langchain.routing_chains import load_routing_chain, Tool\n", + "from langchain.llms import OpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "56ff7670", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the tool configs that are needed.\n", + "from langchain import LLMMathChain, SerpAPIChain\n", + "llm = OpenAI(temperature=0)\n", + "search = SerpAPIChain()\n", + "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", + "tools = [\n", + " Tool(\n", + " name = \"Search\",\n", + " func=search.run,\n", + " description=\"useful for when you need to answer questions about current events\"\n", + " ),\n", + " Tool(\n", + " name=\"Calculator\",\n", + " func=llm_math_chain.run,\n", + " description=\"useful for when you need to answer questions about math\"\n", + " )\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5b93047d", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct the routing chain. We will use the default router type here.\n", + "# See documentation for a full list of options.\n", + "router_llm = OpenAI(temperature=0)\n", + "chain = load_routing_chain(tools, router_llm, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6f96a891", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Olivia Wilde's boyfriend\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Harry Styles\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3m28 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 28 to the 0.23 power\n", + "Action: Calculator\n", + "Action Input: 28^0.23\u001b[0m\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "28^0.23\u001b[32;1m\u001b[1;3m\n", + "\n", + "```python\n", + "print(28**0.23)\n", + "```\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m2.1520202182226886\n", + "\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.1520202182226886\n", + "\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 2.1520202182226886\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'2.1520202182226886'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "d3254c6c", + "id": "2f0852ff", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/routing_chains/__init__.py b/langchain/routing_chains/__init__.py index 2a53116a652..d2d3c682c29 100644 --- a/langchain/routing_chains/__init__.py +++ b/langchain/routing_chains/__init__.py @@ -1,9 +1,11 @@ """Routing chains.""" +from langchain.routing_chains.loading import load_routing_chain from langchain.routing_chains.mrkl.base import MRKLChain from langchain.routing_chains.react.base import ReActChain from langchain.routing_chains.router import LLMRouter from langchain.routing_chains.routing_chain import RoutingChain from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchChain +from langchain.routing_chains.tools import Tool __all__ = [ "MRKLChain", @@ -11,4 +13,6 @@ __all__ = [ "ReActChain", "LLMRouter", "RoutingChain", + "Tool", + "load_routing_chain", ] diff --git a/langchain/routing_chains/loading.py b/langchain/routing_chains/loading.py new file mode 100644 index 00000000000..35b473c0ded --- /dev/null +++ b/langchain/routing_chains/loading.py @@ -0,0 +1,43 @@ +"""Load routing chains.""" +from typing import Any, List + +from langchain.llms.base import LLM +from langchain.routing_chains.mrkl.base import ZeroShotRouter +from langchain.routing_chains.react.base import ReActDocstoreRouter +from langchain.routing_chains.routing_chain import RoutingChain +from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchRouter +from langchain.routing_chains.tools import Tool + +ROUTER_TYPE_TO_CLASS = { + "zero-shot-react-description": ZeroShotRouter, + "react-docstore": ReActDocstoreRouter, + "self-ask-with-search": SelfAskWithSearchRouter, +} + + +def load_routing_chain( + tools: List[Tool], + llm: LLM, + router_type: str = "zero-shot-react-description", + **kwargs: Any, +) -> RoutingChain: + """Load routing chain given tools and LLM. + + Args: + tools: List of tools this routing chain has access to. + llm: Language model to use as the router. + router_type: The router to use. Valid options are: + `zero-shot-react-description`. + **kwargs: Additional key word arguments to pass to the routing chain. + + Returns: + A routing chain. + """ + if router_type not in ROUTER_TYPE_TO_CLASS: + raise ValueError( + f"Got unknown router type: {router_type}. " + f"Valid types are: {ROUTER_TYPE_TO_CLASS.keys()}." + ) + router_cls = ROUTER_TYPE_TO_CLASS[router_type] + router = router_cls.from_llm_and_tools(llm, tools) + return RoutingChain(router=router, tools=tools, **kwargs) diff --git a/langchain/routing_chains/mrkl/base.py b/langchain/routing_chains/mrkl/base.py index cc714280e18..c6929814aa9 100644 --- a/langchain/routing_chains/mrkl/base.py +++ b/langchain/routing_chains/mrkl/base.py @@ -6,7 +6,8 @@ from langchain.llms.base import LLM from langchain.prompts import PromptTemplate from langchain.routing_chains.mrkl.prompt import BASE_TEMPLATE from langchain.routing_chains.router import LLMRouter -from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig +from langchain.routing_chains.routing_chain import RoutingChain +from langchain.routing_chains.tools import Tool FINAL_ANSWER_ACTION = "Final Answer: " @@ -46,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]: return action, action_input.strip(" ").strip('"') -class MRKLRouterChain(LLMRouter): +class ZeroShotRouter(LLMRouter): """Router for the MRKL chain.""" @property @@ -59,17 +60,15 @@ class MRKLRouterChain(LLMRouter): """Prefix to append the router call with.""" return "Thought:" - def __init__(self, llm: LLM, chain_configs: List[ChainConfig], **kwargs: Any): - """Initialize with an LLM and the chain configs it has access to.""" - tools = "\n".join( - [f"{c.action_name}: {c.action_description}" for c in chain_configs] - ) - tool_names = ", ".join([chain.action_name for chain in chain_configs]) - template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names) + @classmethod + def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "ZeroShotRouter": + """Construct a router from an LLM and tools.""" + tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) + tool_names = ", ".join([tool.name for tool in tools]) + template = BASE_TEMPLATE.format(tools=tool_strings, tool_names=tool_names) prompt = PromptTemplate(template=template, input_variables=["input"]) llm_chain = LLMChain(llm=llm, prompt=prompt) - stops = ["\nObservation"] - super().__init__(llm_chain=llm_chain, stops=stops, **kwargs) + return cls(llm_chain=llm_chain) def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]: return get_action_and_input(text) @@ -128,8 +127,50 @@ class MRKLChain(RoutingChain): ] mrkl = MRKLChain.from_chains(llm, chains) """ - router_chain = MRKLRouterChain(llm, chains) - expert_configs = [ - ToolConfig(tool_name=c.action_name, tool=c.action) for c in chains + tools = [ + Tool(name=c.action_name, func=c.action, description=c.action_description) + for c in chains ] - return cls(router_chain=router_chain, expert_configs=expert_configs, **kwargs) + return cls.from_tools_and_llm(tools, llm, **kwargs) + + @classmethod + def from_tools_and_llm( + cls, tools: List[Tool], llm: LLM, **kwargs: Any + ) -> "MRKLChain": + """User friendly way to initialize the MRKL chain. + + This is intended to be an easy way to get up and running with the + MRKL chain. + + Args: + tools: The tools the MRKL system has access to. + llm: The LLM to use as the router LLM. + **kwargs: parameters to be passed to initialization. + + Returns: + An initialized MRKL chain. + + Example: + .. code-block:: python + + from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain + from langchain.routing_chains.tools import ToolConfig + llm = OpenAI(temperature=0) + search = SerpAPIChain() + llm_math_chain = LLMMathChain(llm=llm) + tools = [ + ToolConfig( + tool_name = "Search", + tool=search.search, + tool_description="useful for searching" + ), + ToolConfig( + tool_name="Calculator", + tool=llm_math_chain.run, + tool_description="useful for doing math" + ) + ] + mrkl = MRKLChain.from_tools_and_llm(llm, tools) + """ + router = ZeroShotRouter.from_llm_and_tools(llm, tools) + return cls(router=router, tools=tools, **kwargs) diff --git a/langchain/routing_chains/react/base.py b/langchain/routing_chains/react/base.py index 113465d2773..e02a98660f0 100644 --- a/langchain/routing_chains/react/base.py +++ b/langchain/routing_chains/react/base.py @@ -1,6 +1,6 @@ """Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf.""" import re -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple from pydantic import BaseModel @@ -10,19 +10,28 @@ from langchain.docstore.document import Document from langchain.llms.base import LLM from langchain.routing_chains.react.prompt import PROMPT from langchain.routing_chains.router import LLMRouter -from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig +from langchain.routing_chains.routing_chain import RoutingChain +from langchain.routing_chains.tools import Tool -class ReActRouterChain(LLMRouter, BaseModel): +class ReActDocstoreRouter(LLMRouter, BaseModel): """Router for the ReAct chin.""" i: int = 1 - def __init__(self, llm: LLM, **kwargs: Any): - """Initialize with the language model.""" + @classmethod + def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "ReActDocstoreRouter": + """Construct a router from an LLM and tools.""" + if len(tools) != 2: + raise ValueError(f"Exactly two tools must be specified, but got {tools}") + tool_names = {tool.name for tool in tools} + if tool_names != {"Lookup", "Search"}: + raise ValueError( + f"Tool names should be Lookup and Search, got {tool_names}" + ) + llm_chain = LLMChain(llm=llm, prompt=PROMPT) - stops = ["\nObservation 1:"] - super().__init__(llm_chain=llm_chain, stops=stops, **kwargs) + return cls(llm_chain=llm_chain) def _fix_text(self, text: str) -> str: return text + f"\nAction {self.i}:" @@ -32,7 +41,6 @@ class ReActRouterChain(LLMRouter, BaseModel): if not text.split("\n")[-1].startswith(action_prefix): return None self.i += 1 - self.stops = [f"\nObservation {self.i}:"] action_block = text.split("\n")[-1] action_str = action_block[len(action_prefix) :] @@ -43,8 +51,8 @@ class ReActRouterChain(LLMRouter, BaseModel): return re_matches.group(1), re_matches.group(2) @property - def finish_action_name(self) -> str: - """Name of the action of when to finish the chain.""" + def finish_tool_name(self) -> str: + """Name of the tool of when to finish the chain.""" return "Finish" @property @@ -52,6 +60,10 @@ class ReActRouterChain(LLMRouter, BaseModel): """Prefix to append the observation with.""" return f"Observation {self.i - 1}: " + @property + def _stop(self) -> List[str]: + return [f"\nObservation {self.i}: "] + @property def router_prefix(self) -> str: """Prefix to append the router call with.""" @@ -95,10 +107,10 @@ class ReActChain(RoutingChain): def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any): """Initialize with the LLM and a docstore.""" - router = ReActRouterChain(llm) docstore_explorer = DocstoreExplorer(docstore) - tool_configs = [ - ToolConfig(tool_name="Search", tool=docstore_explorer.search), - ToolConfig(tool_name="Lookup", tool=docstore_explorer.lookup), + tools = [ + Tool(name="Search", func=docstore_explorer.search), + Tool(name="Lookup", func=docstore_explorer.lookup), ] - super().__init__(router=router, expert_configs=tool_configs, **kwargs) + router = ReActDocstoreRouter.from_llm_and_tools(llm, tools) + super().__init__(router=router, tools=tools, **kwargs) diff --git a/langchain/routing_chains/router.py b/langchain/routing_chains/router.py index cff56b3d90c..01cdea9dc2b 100644 --- a/langchain/routing_chains/router.py +++ b/langchain/routing_chains/router.py @@ -1,10 +1,12 @@ """Chain that takes in an input and produces an action and action input.""" from abc import ABC, abstractmethod -from typing import NamedTuple, Optional, Tuple +from typing import List, NamedTuple, Optional, Tuple from pydantic import BaseModel from langchain.chains.llm import LLMChain +from langchain.llms.base import LLM +from langchain.routing_chains.tools import Tool class RouterOutput(NamedTuple): @@ -63,6 +65,15 @@ class LLMRouter(Router, BaseModel, ABC): """Fix the text.""" raise ValueError("fix_text not implemented for this router.") + @property + def _stop(self) -> List[str]: + return [f"\n{self.observation_prefix}"] + + @classmethod + @abstractmethod + def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "Router": + """Construct a router from an LLM and tools.""" + def route(self, text: str) -> RouterOutput: """Given input, decided how to route it. @@ -73,12 +84,12 @@ class LLMRouter(Router, BaseModel, ABC): RouterOutput specifying what tool to use. """ input_key = self.llm_chain.input_keys[0] - inputs = {input_key: text, "stop": [self.observation_prefix]} + inputs = {input_key: text, "stop": self._stop} full_output = self.llm_chain.predict(**inputs) parsed_output = self._extract_tool_and_input(full_output) while parsed_output is None: full_output = self._fix_text(full_output) - inputs = {input_key: text + full_output, "stop": [self.observation_prefix]} + inputs = {input_key: text + full_output, "stop": self._stop} output = self.llm_chain.predict(**inputs) full_output += output parsed_output = self._extract_tool_and_input(full_output) diff --git a/langchain/routing_chains/routing_chain.py b/langchain/routing_chains/routing_chain.py index b49eeeeba3c..b9ff4389635 100644 --- a/langchain/routing_chains/routing_chain.py +++ b/langchain/routing_chains/routing_chain.py @@ -1,18 +1,12 @@ """Router-Expert framework.""" -from typing import Callable, Dict, List, NamedTuple +from typing import Dict, List from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.input import ChainedInput, get_color_mapping from langchain.routing_chains.router import Router - - -class ToolConfig(NamedTuple): - """Configuration for tools.""" - - tool_name: str - tool: Callable[[str], str] +from langchain.routing_chains.tools import Tool class RoutingChain(Chain, BaseModel): @@ -20,8 +14,8 @@ class RoutingChain(Chain, BaseModel): router: Router """Router to use.""" - tool_configs: List[ToolConfig] - """Tool configs this chain has access to.""" + tools: List[Tool] + """Tools this chain has access to.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: @@ -49,7 +43,7 @@ class RoutingChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: # Construct a mapping of tool name to tool for easy lookup - name_to_tool_map = {tc.tool_name: tc.tool for tc in self.tool_configs} + name_to_tool_map = {tool.name: tool.func for tool in self.tools} # Construct the initial string to pass into the router. This is made up # of the user input, the special starter string, and then the router prefix. # The starter string is a special string that may be used by a router to @@ -64,7 +58,7 @@ class RoutingChain(Chain, BaseModel): chained_input = ChainedInput(starter_string, verbose=self.verbose) # We construct a mapping from each tool to a color, used for logging. color_mapping = get_color_mapping( - [c.tool_name for c in self.tool_configs], excluded_colors=["green"] + [tool.name for tool in self.tools], excluded_colors=["green"] ) # We now enter the router loop (until it returns something). while True: diff --git a/langchain/routing_chains/self_ask_with_search/base.py b/langchain/routing_chains/self_ask_with_search/base.py index ab4f8543eac..81af5543c64 100644 --- a/langchain/routing_chains/self_ask_with_search/base.py +++ b/langchain/routing_chains/self_ask_with_search/base.py @@ -1,21 +1,33 @@ """Chain that does self ask with search.""" -from typing import Any, Tuple +from typing import Any, List, Tuple from langchain.chains.llm import LLMChain from langchain.chains.serpapi import SerpAPIChain from langchain.llms.base import LLM from langchain.routing_chains.router import LLMRouter -from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig +from langchain.routing_chains.routing_chain import RoutingChain from langchain.routing_chains.self_ask_with_search.prompt import PROMPT +from langchain.routing_chains.tools import Tool class SelfAskWithSearchRouter(LLMRouter): """Router for the self-ask-with-search paper.""" - def __init__(self, llm: LLM, **kwargs: Any): - """Initialize with an LLM.""" + @classmethod + def from_llm_and_tools( + cls, llm: LLM, tools: List[Tool] + ) -> "SelfAskWithSearchRouter": + """Construct a router from an LLM and tools.""" + if len(tools) != 1: + raise ValueError(f"Exactly one tool must be specified, but got {tools}") + tool_names = {tool.name for tool in tools} + if tool_names != {"Intermediate Answer"}: + raise ValueError( + f"Tool name should be Intermediate Answer, got {tool_names}" + ) + llm_chain = LLMChain(llm=llm, prompt=PROMPT) - super().__init__(llm_chain=llm_chain, **kwargs) + return cls(llm_chain=llm_chain, tools=tools) def _extract_tool_and_input(self, text: str) -> Tuple[str, str]: followup = "Follow up:" @@ -71,8 +83,6 @@ class SelfAskWithSearchChain(RoutingChain): def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any): """Initialize with just an LLM and a search chain.""" - intermediate = "\nIntermediate answer:" - router = SelfAskWithSearchRouter(llm, stops=[intermediate]) - search_tool = ToolConfig(tool_name="Intermediate Answer", tool=search_chain.run) - expert_configs = [search_tool] - super().__init__(router=router, expert_configs=expert_configs, **kwargs) + search_tool = Tool(name="Intermediate Answer", func=search_chain.run) + router = SelfAskWithSearchRouter.from_llm_and_tools(llm, [search_tool]) + super().__init__(router=router, tools=[search_tool], **kwargs) diff --git a/langchain/routing_chains/tools.py b/langchain/routing_chains/tools.py new file mode 100644 index 00000000000..2b04db21d9f --- /dev/null +++ b/langchain/routing_chains/tools.py @@ -0,0 +1,10 @@ +"""Interface for tools.""" +from typing import Callable, NamedTuple, Optional + + +class Tool(NamedTuple): + """Interface for tools.""" + + name: str + func: Callable[[str], str] + description: Optional[str] = None diff --git a/tests/unit_tests/routing_chains/test_mrkl.py b/tests/unit_tests/routing_chains/test_mrkl.py index 29d833f1201..66781088438 100644 --- a/tests/unit_tests/routing_chains/test_mrkl.py +++ b/tests/unit_tests/routing_chains/test_mrkl.py @@ -3,12 +3,9 @@ import pytest from langchain.prompts import PromptTemplate -from langchain.routing_chains.mrkl.base import ( - ChainConfig, - MRKLRouterChain, - get_action_and_input, -) +from langchain.routing_chains.mrkl.base import ZeroShotRouter, get_action_and_input from langchain.routing_chains.mrkl.prompt import BASE_TEMPLATE +from langchain.routing_chains.tools import Tool from tests.unit_tests.llms.fake_llm import FakeLLM @@ -56,14 +53,10 @@ def test_bad_action_line() -> None: def test_from_chains() -> None: """Test initializing from chains.""" chain_configs = [ - ChainConfig( - action_name="foo", action=lambda x: "foo", action_description="foobar1" - ), - ChainConfig( - action_name="bar", action=lambda x: "bar", action_description="foobar2" - ), + Tool(name="foo", func=lambda x: "foo", description="foobar1"), + Tool(name="bar", func=lambda x: "bar", description="foobar2"), ] - router_chain = MRKLRouterChain(FakeLLM(), chain_configs) + router_chain = ZeroShotRouter.from_llm_and_tools(FakeLLM(), chain_configs) expected_tools_prompt = "foo: foobar1\nbar: foobar2" expected_tool_names = "foo, bar" expected_template = BASE_TEMPLATE.format( diff --git a/tests/unit_tests/routing_chains/test_react.py b/tests/unit_tests/routing_chains/test_react.py index 9b078f90ed5..9e82650040c 100644 --- a/tests/unit_tests/routing_chains/test_react.py +++ b/tests/unit_tests/routing_chains/test_react.py @@ -8,7 +8,7 @@ from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM from langchain.prompts.prompt import PromptTemplate -from langchain.routing_chains.react.base import ReActChain, ReActRouterChain +from langchain.routing_chains.react.base import ReActChain, ReActDocstoreRouter _PAGE_CONTENT = """This is a page about LangChain. @@ -52,7 +52,7 @@ def test_predict_until_observation_normal() -> None: """Test predict_until_observation when observation is made normally.""" outputs = ["foo\nAction 1: search[foo]"] fake_llm = FakeListLLM(outputs) - router_chain = ReActRouterChain(llm=fake_llm) + router_chain = ReActDocstoreRouter(llm=fake_llm) output = router_chain.route("") assert output.log == outputs[0] assert output.tool == "search" @@ -63,7 +63,7 @@ def test_predict_until_observation_repeat() -> None: """Test when no action is generated initially.""" outputs = ["foo", " search[foo]"] fake_llm = FakeListLLM(outputs) - router_chain = ReActRouterChain(llm=fake_llm) + router_chain = ReActDocstoreRouter(llm=fake_llm) output = router_chain.route("") assert output.log == "foo\nAction 1: search[foo]" assert output.tool == "search"