mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 08:06:14 +00:00
cr
This commit is contained in:
parent
4ccb9b684a
commit
505cb2eb62
@ -27,7 +27,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain, SQLDatabase, SQLDatabaseChain\n",
|
||||
"from langchain.chains.mrkl.base import ChainConfig"
|
||||
"from langchain.routing_chains.mrkl.base import ChainConfig"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -167,7 +167,7 @@
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"What albums by Alanis Morissette are in the FooBar database?\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morissette'\u001b[0m\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = \"Alanis Morissette\"\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m Jagged Little Pill\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
|
@ -41,10 +41,12 @@
|
||||
"with.\n",
|
||||
"Action 1: Search[David Chanoff]\u001b[0m\n",
|
||||
"Observation 1: \u001b[36;1m\u001b[1;3mDavid Chanoff is a noted author of non-fiction work. His work has typically involved collaborations with the principal protagonist of the work concerned. His collaborators have included; Augustus A. White, Joycelyn Elders, Đoàn Văn Toại, William J. Crowe, Ariel Sharon, Kenneth Good and Felix Zandman. He has also written about a wide range of subjects including literary history, education and foreign for The Washington Post, The New Republic and The New York Times Magazine. He has published more than twelve books.\u001b[0m\n",
|
||||
"Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe.\n",
|
||||
"Thought 2:\u001b[32;1m\u001b[1;3m The U.S. Navy admiral David Chanoff collaborated with is William J. Crowe. I\n",
|
||||
"need to search him next.\n",
|
||||
"Action 2: Search[William J. Crowe]\u001b[0m\n",
|
||||
"Observation 2: \u001b[36;1m\u001b[1;3mWilliam James Crowe Jr. (January 2, 1925 – October 18, 2007) was a United States Navy admiral and diplomat who served as the 11th chairman of the Joint Chiefs of Staff under Presidents Ronald Reagan and George H. W. Bush, and as the ambassador to the United Kingdom and Chair of the Intelligence Oversight Board under President Bill Clinton.\u001b[0m\n",
|
||||
"Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under President Bill Clinton.\n",
|
||||
"Thought 3:\u001b[32;1m\u001b[1;3m William J. Crowe served as the ambassador to the United Kingdom under\n",
|
||||
"President Bill Clinton. So the answer is Bill Clinton.\n",
|
||||
"Action 3: Finish[Bill Clinton]\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@ -68,7 +70,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3cb9d77c",
|
||||
"id": "4ff64e81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4b21ae68",
|
||||
"id": "3c6226b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Concepts\n",
|
||||
@ -26,20 +26,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c9677868",
|
||||
"id": "05d4b21e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tools\n",
|
||||
"When constructing your own Routing Chain, you will need to provide it with a list of tools that it can use. This is done with a list of ToolConfigs. The ToolConfig is used not only to create the Routing Chain, but is also sometimes used to create the router itself (often, the router logic depends on the tools available). \n",
|
||||
"When constructing your own Routing Chain, you will need to provide it with a list of tools that it can use. This is done with a list of Tools. Tools are used not only to create the Routing Chain, but are also sometimes used to create the router itself (often, the router logic depends on the tools available). \n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"class ToolConfig(NamedTuple):\n",
|
||||
" \"\"\"Configuration for tools.\"\"\"\n",
|
||||
"class Tool(NamedTuple):\n",
|
||||
" \"\"\"Interface for tools.\"\"\"\n",
|
||||
"\n",
|
||||
" tool_name: str\n",
|
||||
" tool: Callable[[str], str]\n",
|
||||
" # Needed to construct some routers.\n",
|
||||
" tool_description: Optional[str] = None\n",
|
||||
" name: str\n",
|
||||
" func: Callable[[str], str]\n",
|
||||
" description: Optional[str] = None\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The two required components of a Tool are the name and the function itself. A description is optional, as it is needed for some routers but not all."
|
||||
@ -50,22 +49,123 @@
|
||||
"id": "2558a02d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading the chains\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"from langchain.routing_chains import load_chain\n",
|
||||
"\n",
|
||||
"tools: List[ToolConfig] = [...]\n",
|
||||
"llm: LLM = OpenAI(temperature=0)\n",
|
||||
"router_type: str = \"zero-shot\"\n",
|
||||
"chain = load_chain(tools, llm, router_type, verbose=True)\n",
|
||||
"```"
|
||||
"## Loading the chains\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "36ed392e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Import things that are needed generically\n",
|
||||
"from langchain.routing_chains import load_routing_chain, Tool\n",
|
||||
"from langchain.llms import OpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "56ff7670",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the tools that are needed.\n",
|
||||
"from langchain import LLMMathChain, SerpAPIChain\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"search = SerpAPIChain()\n",
|
||||
"llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name = \"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events\"\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"Calculator\",\n",
|
||||
" func=llm_math_chain.run,\n",
|
||||
" description=\"useful for when you need to answer questions about math\"\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "5b93047d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Construct the routing chain. We will use the default router type here.\n",
|
||||
"# See documentation for a full list of options.\n",
|
||||
"router_llm = OpenAI(temperature=0)\n",
|
||||
"chain = load_routing_chain(tools, router_llm, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "6f96a891",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Olivia Wilde's boyfriend\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Harry Styles\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"Harry Styles age\"\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m28 years\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 28 to the 0.23 power\n",
|
||||
"Action: Calculator\n",
|
||||
"Action Input: 28^0.23\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"28^0.23\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"print(28**0.23)\n",
|
||||
"```\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m2.1520202182226886\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.1520202182226886\n",
|
||||
"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
|
||||
"Final Answer: 2.1520202182226886\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'2.1520202182226886'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d3254c6c",
|
||||
"id": "2f0852ff",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
@ -1,9 +1,11 @@
|
||||
"""Routing chains."""
|
||||
from langchain.routing_chains.loading import load_routing_chain
|
||||
from langchain.routing_chains.mrkl.base import MRKLChain
|
||||
from langchain.routing_chains.react.base import ReActChain
|
||||
from langchain.routing_chains.router import LLMRouter
|
||||
from langchain.routing_chains.routing_chain import RoutingChain
|
||||
from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
__all__ = [
|
||||
"MRKLChain",
|
||||
@ -11,4 +13,6 @@ __all__ = [
|
||||
"ReActChain",
|
||||
"LLMRouter",
|
||||
"RoutingChain",
|
||||
"Tool",
|
||||
"load_routing_chain",
|
||||
]
|
||||
|
43
langchain/routing_chains/loading.py
Normal file
43
langchain/routing_chains/loading.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""Load routing chains."""
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.routing_chains.mrkl.base import ZeroShotRouter
|
||||
from langchain.routing_chains.react.base import ReActDocstoreRouter
|
||||
from langchain.routing_chains.routing_chain import RoutingChain
|
||||
from langchain.routing_chains.self_ask_with_search.base import SelfAskWithSearchRouter
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
# Registry mapping each supported ``router_type`` string to its router class.
ROUTER_TYPE_TO_CLASS = {
    "zero-shot-react-description": ZeroShotRouter,
    "react-docstore": ReActDocstoreRouter,
    "self-ask-with-search": SelfAskWithSearchRouter,
}


def load_routing_chain(
    tools: List[Tool],
    llm: LLM,
    router_type: str = "zero-shot-react-description",
    **kwargs: Any,
) -> RoutingChain:
    """Load routing chain given tools and LLM.

    Args:
        tools: List of tools this routing chain has access to.
        llm: Language model to use as the router.
        router_type: The router to use. Valid options are the keys of
            ``ROUTER_TYPE_TO_CLASS``:
            `zero-shot-react-description`, `react-docstore` and
            `self-ask-with-search`.
        **kwargs: Additional key word arguments to pass to the routing chain.

    Returns:
        A routing chain.

    Raises:
        ValueError: If ``router_type`` is not a key of
            ``ROUTER_TYPE_TO_CLASS``.
    """
    if router_type not in ROUTER_TYPE_TO_CLASS:
        # Render the valid names as a plain list; interpolating ``.keys()``
        # would print the ``dict_keys([...])`` wrapper in the message.
        raise ValueError(
            f"Got unknown router type: {router_type}. "
            f"Valid types are: {list(ROUTER_TYPE_TO_CLASS)}."
        )
    router_cls = ROUTER_TYPE_TO_CLASS[router_type]
    # Each router class builds itself from the LLM and the tool list.
    router = router_cls.from_llm_and_tools(llm, tools)
    return RoutingChain(router=router, tools=tools, **kwargs)
|
@ -6,7 +6,8 @@ from langchain.llms.base import LLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.routing_chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.routing_chains.router import LLMRouter
|
||||
from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig
|
||||
from langchain.routing_chains.routing_chain import RoutingChain
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer: "
|
||||
|
||||
@ -46,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
|
||||
return action, action_input.strip(" ").strip('"')
|
||||
|
||||
|
||||
class MRKLRouterChain(LLMRouter):
|
||||
class ZeroShotRouter(LLMRouter):
|
||||
"""Router for the MRKL chain."""
|
||||
|
||||
@property
|
||||
@ -59,17 +60,15 @@ class MRKLRouterChain(LLMRouter):
|
||||
"""Prefix to append the router call with."""
|
||||
return "Thought:"
|
||||
|
||||
def __init__(self, llm: LLM, chain_configs: List[ChainConfig], **kwargs: Any):
|
||||
"""Initialize with an LLM and the chain configs it has access to."""
|
||||
tools = "\n".join(
|
||||
[f"{c.action_name}: {c.action_description}" for c in chain_configs]
|
||||
)
|
||||
tool_names = ", ".join([chain.action_name for chain in chain_configs])
|
||||
template = BASE_TEMPLATE.format(tools=tools, tool_names=tool_names)
|
||||
@classmethod
|
||||
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "ZeroShotRouter":
|
||||
"""Construct a router from an LLM and tools."""
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
template = BASE_TEMPLATE.format(tools=tool_strings, tool_names=tool_names)
|
||||
prompt = PromptTemplate(template=template, input_variables=["input"])
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
stops = ["\nObservation"]
|
||||
super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)
|
||||
return cls(llm_chain=llm_chain)
|
||||
|
||||
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
|
||||
return get_action_and_input(text)
|
||||
@ -128,8 +127,50 @@ class MRKLChain(RoutingChain):
|
||||
]
|
||||
mrkl = MRKLChain.from_chains(llm, chains)
|
||||
"""
|
||||
router_chain = MRKLRouterChain(llm, chains)
|
||||
expert_configs = [
|
||||
ToolConfig(tool_name=c.action_name, tool=c.action) for c in chains
|
||||
tools = [
|
||||
Tool(name=c.action_name, func=c.action, description=c.action_description)
|
||||
for c in chains
|
||||
]
|
||||
return cls(router_chain=router_chain, expert_configs=expert_configs, **kwargs)
|
||||
return cls.from_tools_and_llm(tools, llm, **kwargs)
|
||||
|
||||
@classmethod
def from_tools_and_llm(
    cls, tools: List[Tool], llm: LLM, **kwargs: Any
) -> "MRKLChain":
    """User friendly way to initialize the MRKL chain.

    This is intended to be an easy way to get up and running with the
    MRKL chain.

    Args:
        tools: The tools the MRKL system has access to.
        llm: The LLM to use as the router LLM.
        **kwargs: parameters to be passed to initialization.

    Returns:
        An initialized MRKL chain.

    Example:
        .. code-block:: python

            from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
            from langchain.routing_chains.tools import Tool
            llm = OpenAI(temperature=0)
            search = SerpAPIChain()
            llm_math_chain = LLMMathChain(llm=llm)
            tools = [
                Tool(
                    name="Search",
                    func=search.search,
                    description="useful for searching"
                ),
                Tool(
                    name="Calculator",
                    func=llm_math_chain.run,
                    description="useful for doing math"
                )
            ]
            mrkl = MRKLChain.from_tools_and_llm(tools, llm)
    """
    router = ZeroShotRouter.from_llm_and_tools(llm, tools)
    return cls(router=router, tools=tools, **kwargs)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
|
||||
import re
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -10,19 +10,28 @@ from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.routing_chains.react.prompt import PROMPT
|
||||
from langchain.routing_chains.router import LLMRouter
|
||||
from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig
|
||||
from langchain.routing_chains.routing_chain import RoutingChain
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
|
||||
class ReActRouterChain(LLMRouter, BaseModel):
|
||||
class ReActDocstoreRouter(LLMRouter, BaseModel):
|
||||
"""Router for the ReAct chain."""
|
||||
|
||||
i: int = 1
|
||||
|
||||
def __init__(self, llm: LLM, **kwargs: Any):
|
||||
"""Initialize with the language model."""
|
||||
@classmethod
|
||||
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "ReActDocstoreRouter":
|
||||
"""Construct a router from an LLM and tools."""
|
||||
if len(tools) != 2:
|
||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
if tool_names != {"Lookup", "Search"}:
|
||||
raise ValueError(
|
||||
f"Tool names should be Lookup and Search, got {tool_names}"
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||
stops = ["\nObservation 1:"]
|
||||
super().__init__(llm_chain=llm_chain, stops=stops, **kwargs)
|
||||
return cls(llm_chain=llm_chain)
|
||||
|
||||
def _fix_text(self, text: str) -> str:
|
||||
return text + f"\nAction {self.i}:"
|
||||
@ -32,7 +41,6 @@ class ReActRouterChain(LLMRouter, BaseModel):
|
||||
if not text.split("\n")[-1].startswith(action_prefix):
|
||||
return None
|
||||
self.i += 1
|
||||
self.stops = [f"\nObservation {self.i}:"]
|
||||
action_block = text.split("\n")[-1]
|
||||
|
||||
action_str = action_block[len(action_prefix) :]
|
||||
@ -43,8 +51,8 @@ class ReActRouterChain(LLMRouter, BaseModel):
|
||||
return re_matches.group(1), re_matches.group(2)
|
||||
|
||||
@property
|
||||
def finish_action_name(self) -> str:
|
||||
"""Name of the action of when to finish the chain."""
|
||||
def finish_tool_name(self) -> str:
|
||||
"""Name of the tool of when to finish the chain."""
|
||||
return "Finish"
|
||||
|
||||
@property
|
||||
@ -52,6 +60,10 @@ class ReActRouterChain(LLMRouter, BaseModel):
|
||||
"""Prefix to append the observation with."""
|
||||
return f"Observation {self.i - 1}: "
|
||||
|
||||
@property
|
||||
def _stop(self) -> List[str]:
|
||||
return [f"\nObservation {self.i}: "]
|
||||
|
||||
@property
|
||||
def router_prefix(self) -> str:
|
||||
"""Prefix to append the router call with."""
|
||||
@ -95,10 +107,10 @@ class ReActChain(RoutingChain):
|
||||
|
||||
def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
|
||||
"""Initialize with the LLM and a docstore."""
|
||||
router = ReActRouterChain(llm)
|
||||
docstore_explorer = DocstoreExplorer(docstore)
|
||||
tool_configs = [
|
||||
ToolConfig(tool_name="Search", tool=docstore_explorer.search),
|
||||
ToolConfig(tool_name="Lookup", tool=docstore_explorer.lookup),
|
||||
tools = [
|
||||
Tool(name="Search", func=docstore_explorer.search),
|
||||
Tool(name="Lookup", func=docstore_explorer.lookup),
|
||||
]
|
||||
super().__init__(router=router, expert_configs=tool_configs, **kwargs)
|
||||
router = ReActDocstoreRouter.from_llm_and_tools(llm, tools)
|
||||
super().__init__(router=router, tools=tools, **kwargs)
|
||||
|
@ -1,10 +1,12 @@
|
||||
"""Chain that takes in an input and produces an action and action input."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
from typing import List, NamedTuple, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
|
||||
class RouterOutput(NamedTuple):
|
||||
@ -63,6 +65,15 @@ class LLMRouter(Router, BaseModel, ABC):
|
||||
"""Fix the text."""
|
||||
raise ValueError("fix_text not implemented for this router.")
|
||||
|
||||
@property
|
||||
def _stop(self) -> List[str]:
|
||||
return [f"\n{self.observation_prefix}"]
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> "Router":
|
||||
"""Construct a router from an LLM and tools."""
|
||||
|
||||
def route(self, text: str) -> RouterOutput:
|
||||
"""Given input, decided how to route it.
|
||||
|
||||
@ -73,12 +84,12 @@ class LLMRouter(Router, BaseModel, ABC):
|
||||
RouterOutput specifying what tool to use.
|
||||
"""
|
||||
input_key = self.llm_chain.input_keys[0]
|
||||
inputs = {input_key: text, "stop": [self.observation_prefix]}
|
||||
inputs = {input_key: text, "stop": self._stop}
|
||||
full_output = self.llm_chain.predict(**inputs)
|
||||
parsed_output = self._extract_tool_and_input(full_output)
|
||||
while parsed_output is None:
|
||||
full_output = self._fix_text(full_output)
|
||||
inputs = {input_key: text + full_output, "stop": [self.observation_prefix]}
|
||||
inputs = {input_key: text + full_output, "stop": self._stop}
|
||||
output = self.llm_chain.predict(**inputs)
|
||||
full_output += output
|
||||
parsed_output = self._extract_tool_and_input(full_output)
|
||||
|
@ -1,18 +1,12 @@
|
||||
"""Router-Expert framework."""
|
||||
from typing import Callable, Dict, List, NamedTuple
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.routing_chains.router import Router
|
||||
|
||||
|
||||
class ToolConfig(NamedTuple):
|
||||
"""Configuration for tools."""
|
||||
|
||||
tool_name: str
|
||||
tool: Callable[[str], str]
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
|
||||
class RoutingChain(Chain, BaseModel):
|
||||
@ -20,8 +14,8 @@ class RoutingChain(Chain, BaseModel):
|
||||
|
||||
router: Router
|
||||
"""Router to use."""
|
||||
tool_configs: List[ToolConfig]
|
||||
"""Tool configs this chain has access to."""
|
||||
tools: List[Tool]
|
||||
"""Tools this chain has access to."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
@ -49,7 +43,7 @@ class RoutingChain(Chain, BaseModel):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
name_to_tool_map = {tc.tool_name: tc.tool for tc in self.tool_configs}
|
||||
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
|
||||
# Construct the initial string to pass into the router. This is made up
|
||||
# of the user input, the special starter string, and then the router prefix.
|
||||
# The starter string is a special string that may be used by a router to
|
||||
@ -64,7 +58,7 @@ class RoutingChain(Chain, BaseModel):
|
||||
chained_input = ChainedInput(starter_string, verbose=self.verbose)
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[c.tool_name for c in self.tool_configs], excluded_colors=["green"]
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
)
|
||||
# We now enter the router loop (until it returns something).
|
||||
while True:
|
||||
|
@ -1,21 +1,33 @@
|
||||
"""Chain that does self ask with search."""
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.serpapi import SerpAPIChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.routing_chains.router import LLMRouter
|
||||
from langchain.routing_chains.routing_chain import RoutingChain, ToolConfig
|
||||
from langchain.routing_chains.routing_chain import RoutingChain
|
||||
from langchain.routing_chains.self_ask_with_search.prompt import PROMPT
|
||||
from langchain.routing_chains.tools import Tool
|
||||
|
||||
|
||||
class SelfAskWithSearchRouter(LLMRouter):
|
||||
"""Router for the self-ask-with-search paper."""
|
||||
|
||||
def __init__(self, llm: LLM, **kwargs: Any):
|
||||
"""Initialize with an LLM."""
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls, llm: LLM, tools: List[Tool]
|
||||
) -> "SelfAskWithSearchRouter":
|
||||
"""Construct a router from an LLM and tools."""
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
tool_names = {tool.name for tool in tools}
|
||||
if tool_names != {"Intermediate Answer"}:
|
||||
raise ValueError(
|
||||
f"Tool name should be Intermediate Answer, got {tool_names}"
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
|
||||
super().__init__(llm_chain=llm_chain, **kwargs)
|
||||
return cls(llm_chain=llm_chain, tools=tools)
|
||||
|
||||
def _extract_tool_and_input(self, text: str) -> Tuple[str, str]:
|
||||
followup = "Follow up:"
|
||||
@ -71,8 +83,6 @@ class SelfAskWithSearchChain(RoutingChain):
|
||||
|
||||
def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
|
||||
"""Initialize with just an LLM and a search chain."""
|
||||
intermediate = "\nIntermediate answer:"
|
||||
router = SelfAskWithSearchRouter(llm, stops=[intermediate])
|
||||
search_tool = ToolConfig(tool_name="Intermediate Answer", tool=search_chain.run)
|
||||
expert_configs = [search_tool]
|
||||
super().__init__(router=router, expert_configs=expert_configs, **kwargs)
|
||||
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
|
||||
router = SelfAskWithSearchRouter.from_llm_and_tools(llm, [search_tool])
|
||||
super().__init__(router=router, tools=[search_tool], **kwargs)
|
||||
|
10
langchain/routing_chains/tools.py
Normal file
10
langchain/routing_chains/tools.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""Interface for tools."""
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
|
||||
class Tool(NamedTuple):
    """A named, string-in/string-out callable with an optional description."""

    # Name under which the tool is registered and looked up.
    name: str
    # Callable taking a single string and returning a string.
    func: Callable[[str], str]
    # Optional free-text description; ``None`` when not supplied.
    description: Optional[str] = None
|
@ -3,12 +3,9 @@
|
||||
import pytest
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.routing_chains.mrkl.base import (
|
||||
ChainConfig,
|
||||
MRKLRouterChain,
|
||||
get_action_and_input,
|
||||
)
|
||||
from langchain.routing_chains.mrkl.base import ZeroShotRouter, get_action_and_input
|
||||
from langchain.routing_chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.routing_chains.tools import Tool
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@ -56,14 +53,10 @@ def test_bad_action_line() -> None:
|
||||
def test_from_chains() -> None:
|
||||
"""Test initializing from chains."""
|
||||
chain_configs = [
|
||||
ChainConfig(
|
||||
action_name="foo", action=lambda x: "foo", action_description="foobar1"
|
||||
),
|
||||
ChainConfig(
|
||||
action_name="bar", action=lambda x: "bar", action_description="foobar2"
|
||||
),
|
||||
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
|
||||
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
|
||||
]
|
||||
router_chain = MRKLRouterChain(FakeLLM(), chain_configs)
|
||||
router_chain = ZeroShotRouter.from_llm_and_tools(FakeLLM(), chain_configs)
|
||||
expected_tools_prompt = "foo: foobar1\nbar: foobar2"
|
||||
expected_tool_names = "foo, bar"
|
||||
expected_template = BASE_TEMPLATE.format(
|
||||
|
@ -8,7 +8,7 @@ from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.routing_chains.react.base import ReActChain, ReActRouterChain
|
||||
from langchain.routing_chains.react.base import ReActChain, ReActDocstoreRouter
|
||||
|
||||
_PAGE_CONTENT = """This is a page about LangChain.
|
||||
|
||||
@ -52,7 +52,7 @@ def test_predict_until_observation_normal() -> None:
|
||||
"""Test predict_until_observation when observation is made normally."""
|
||||
outputs = ["foo\nAction 1: search[foo]"]
|
||||
fake_llm = FakeListLLM(outputs)
|
||||
router_chain = ReActRouterChain(llm=fake_llm)
|
||||
router_chain = ReActDocstoreRouter(llm=fake_llm)
|
||||
output = router_chain.route("")
|
||||
assert output.log == outputs[0]
|
||||
assert output.tool == "search"
|
||||
@ -63,7 +63,7 @@ def test_predict_until_observation_repeat() -> None:
|
||||
"""Test when no action is generated initially."""
|
||||
outputs = ["foo", " search[foo]"]
|
||||
fake_llm = FakeListLLM(outputs)
|
||||
router_chain = ReActRouterChain(llm=fake_llm)
|
||||
router_chain = ReActDocstoreRouter(llm=fake_llm)
|
||||
output = router_chain.route("")
|
||||
assert output.log == "foo\nAction 1: search[foo]"
|
||||
assert output.tool == "search"
|
||||
|
Loading…
Reference in New Issue
Block a user