mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
reorg smart chains
This commit is contained in:
parent
2a84d3d5ca
commit
68eaf4e5ee
@ -20,9 +20,6 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
@ -33,8 +30,6 @@
|
||||
"\u001b[32;1m\u001b[1;3mFollow up: Where is Carlos Alcaraz from?\u001b[0m\n",
|
||||
"Intermediate answer: \u001b[36;1m\u001b[1;3mEl Palmar, Spain\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mSo the final answer is: El Palmar, Spain\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
|
@ -8,10 +8,7 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f:
|
||||
from langchain.chains import (
|
||||
LLMChain,
|
||||
LLMMathChain,
|
||||
MRKLChain,
|
||||
PythonChain,
|
||||
ReActChain,
|
||||
SelfAskWithSearchChain,
|
||||
SerpAPIChain,
|
||||
SQLDatabaseChain,
|
||||
VectorDBQA,
|
||||
@ -19,6 +16,7 @@ from langchain.chains import (
|
||||
from langchain.docstore import Wikipedia
|
||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||
from langchain.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain.smart_chains import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||
|
||||
|
@ -1,10 +1,7 @@
|
||||
"""Chains are easily reusable components which can be linked together."""
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
from langchain.chains.mrkl.base import MRKLChain
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.chains.react.base import ReActChain
|
||||
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
from langchain.chains.serpapi import SerpAPIChain
|
||||
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
@ -13,10 +10,7 @@ __all__ = [
|
||||
"LLMChain",
|
||||
"LLMMathChain",
|
||||
"PythonChain",
|
||||
"SelfAskWithSearchChain",
|
||||
"SerpAPIChain",
|
||||
"ReActChain",
|
||||
"SQLDatabaseChain",
|
||||
"MRKLChain",
|
||||
"VectorDBQA",
|
||||
]
|
||||
|
6
langchain/smart_chains/__init__.py
Normal file
6
langchain/smart_chains/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Smart chains."""
|
||||
from langchain.smart_chains.mrkl.base import MRKLChain
|
||||
from langchain.smart_chains.react.base import ReActChain
|
||||
from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
|
||||
__all__ = ["MRKLChain", "SelfAskWithSearchChain", "ReActChain"]
|
@ -1,16 +1,12 @@
|
||||
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.chains.router import LLMRouterChain
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain.chains.router_expert import RouterExpertChain, ExpertConfig
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.smart_chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.smart_chains.router import LLMRouterChain
|
||||
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer: "
|
||||
|
||||
@ -79,33 +75,20 @@ class MRKLRouterChain(LLMRouterChain):
|
||||
return get_action_and_input(text)
|
||||
|
||||
|
||||
class MRKLChain(Chain, BaseModel):
|
||||
class MRKLChain(RouterExpertChain):
|
||||
"""Chain that implements the MRKL system.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import OpenAI, Prompt, MRKLChain
|
||||
from langchain import OpenAI, MRKLChain
|
||||
from langchain.chains.mrkl.base import ChainConfig
|
||||
llm = OpenAI(temperature=0)
|
||||
prompt = PromptTemplate(...)
|
||||
action_to_chain_map = {...}
|
||||
mrkl = MRKLChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
action_to_chain_map=action_to_chain_map
|
||||
)
|
||||
chains = [...]
|
||||
mrkl = MRKLChain.from_chains(llm=llm, prompt=prompt)
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
"""LLM wrapper to use as router."""
|
||||
chain_configs: List[ChainConfig]
|
||||
"""Chain configs this chain has access to."""
|
||||
action_to_chain_map: Dict[str, Callable]
|
||||
"""Mapping from action name to chain to execute."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
@classmethod
|
||||
def from_chains(
|
||||
cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
|
||||
@ -145,47 +128,8 @@ class MRKLChain(Chain, BaseModel):
|
||||
]
|
||||
mrkl = MRKLChain.from_chains(llm, chains)
|
||||
"""
|
||||
action_to_chain_map = {chain.action_name: chain.action for chain in chains}
|
||||
return cls(
|
||||
llm=llm,
|
||||
chain_configs=chains,
|
||||
action_to_chain_map=action_to_chain_map,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
router_chain = MRKLRouterChain(self.llm, self.chain_configs)
|
||||
question = inputs[self.input_key]
|
||||
router_chain = MRKLRouterChain(llm, chains)
|
||||
expert_configs = [
|
||||
ExpertConfig(expert_name=c.action_name, expert=c.action)
|
||||
for c in self.chain_configs
|
||||
ExpertConfig(expert_name=c.action_name, expert=c.action) for c in chains
|
||||
]
|
||||
chain = RouterExpertChain(
|
||||
router_chain=router_chain,
|
||||
expert_configs=expert_configs,
|
||||
verbose=self.verbose
|
||||
)
|
||||
output = chain.run(question)
|
||||
return {self.output_key: output}
|
||||
return cls(router_chain=router_chain, expert_configs=expert_configs, **kwargs)
|
@ -1,18 +1,16 @@
|
||||
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.react.prompt import PROMPT
|
||||
from langchain.chains.router import LLMRouterChain
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.chains.router_expert import RouterExpertChain, ExpertConfig
|
||||
from langchain.smart_chains.react.prompt import PROMPT
|
||||
from langchain.smart_chains.router import LLMRouterChain
|
||||
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
|
||||
|
||||
|
||||
class ReActRouterChain(LLMRouterChain, BaseModel):
|
||||
@ -46,7 +44,7 @@ class ReActRouterChain(LLMRouterChain, BaseModel):
|
||||
|
||||
@property
|
||||
def finish_action_name(self) -> str:
|
||||
"""The action name of when to finish the chain."""
|
||||
"""Name of the action of when to finish the chain."""
|
||||
return "Finish"
|
||||
|
||||
@property
|
||||
@ -61,12 +59,15 @@ class ReActRouterChain(LLMRouterChain, BaseModel):
|
||||
|
||||
|
||||
class DocstoreExplorer:
|
||||
"""Class to assist with exploration of a document store."""
|
||||
|
||||
def __init__(self, docstore: Docstore):
|
||||
self.docstore=docstore
|
||||
self.document = None
|
||||
"""Initialize with a docstore, and set initial document to None."""
|
||||
self.docstore = docstore
|
||||
self.document: Optional[Document] = None
|
||||
|
||||
def search(self, term: str):
|
||||
def search(self, term: str) -> str:
|
||||
"""Search for a term in the docstore, and if found save."""
|
||||
result = self.docstore.search(term)
|
||||
if isinstance(result, Document):
|
||||
self.document = result
|
||||
@ -75,13 +76,14 @@ class DocstoreExplorer:
|
||||
self.document = None
|
||||
return result
|
||||
|
||||
def lookup(self, term: str):
|
||||
def lookup(self, term: str) -> str:
|
||||
"""Lookup a term in document (if saved)."""
|
||||
if self.document is None:
|
||||
raise ValueError("Cannot lookup without a successful search first")
|
||||
return self.document.lookup(term)
|
||||
|
||||
|
||||
class ReActChain(Chain, BaseModel):
|
||||
class ReActChain(RouterExpertChain):
|
||||
"""Chain that implements the ReAct paper.
|
||||
|
||||
Example:
|
||||
@ -91,47 +93,14 @@ class ReActChain(Chain, BaseModel):
|
||||
react = ReAct(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
"""LLM wrapper to use."""
|
||||
docstore: Docstore
|
||||
"""Docstore to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
question = inputs[self.input_key]
|
||||
router_chain = ReActRouterChain(self.llm)
|
||||
docstore_explorer = DocstoreExplorer(self.docstore)
|
||||
def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
|
||||
"""Initialize with the LLM and a docstore."""
|
||||
router_chain = ReActRouterChain(llm)
|
||||
docstore_explorer = DocstoreExplorer(docstore)
|
||||
expert_configs = [
|
||||
ExpertConfig(expert_name="Search", expert=docstore_explorer.search),
|
||||
ExpertConfig(expert_name="Lookup", expert=docstore_explorer.lookup)
|
||||
ExpertConfig(expert_name="Lookup", expert=docstore_explorer.lookup),
|
||||
]
|
||||
chain = RouterExpertChain(
|
||||
router_chain=router_chain,
|
||||
expert_configs=expert_configs,
|
||||
verbose=self.verbose
|
||||
super().__init__(
|
||||
router_chain=router_chain, expert_configs=expert_configs, **kwargs
|
||||
)
|
||||
output = chain.run(question)
|
||||
return {self.output_key: output}
|
@ -48,7 +48,7 @@ class RouterChain(Chain, BaseModel, ABC):
|
||||
|
||||
@property
|
||||
def finish_action_name(self) -> str:
|
||||
"""The action name of when to finish the chain."""
|
||||
"""Name of the action of when to finish the chain."""
|
||||
return "Final Answer"
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
@ -1,21 +1,15 @@
|
||||
"""Router-Expert framework."""
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
|
||||
from typing import Callable, Dict, List, NamedTuple
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.chains.router import LLMRouterChain
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain.chains.router import RouterChain
|
||||
|
||||
|
||||
from langchain.smart_chains.router import RouterChain
|
||||
|
||||
|
||||
class ExpertConfig(NamedTuple):
|
||||
"""Configuration for experts."""
|
||||
|
||||
expert_name: str
|
||||
expert: Callable[[str], str]
|
||||
@ -57,8 +51,14 @@ class RouterExpertChain(Chain, BaseModel):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
action_to_chain_map = {e.expert_name: e.expert for e in self.expert_configs}
|
||||
starter_string = (
|
||||
inputs[self.input_key]
|
||||
+ self.starter_string
|
||||
+ self.router_chain.router_prefix
|
||||
)
|
||||
chained_input = ChainedInput(
|
||||
f"{inputs[self.input_key]}{self.starter_string}{self.router_chain.router_prefix}", verbose=self.verbose
|
||||
starter_string,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
color_mapping = get_color_mapping(
|
||||
[c.expert_name for c in self.expert_configs], excluded_colors=["green"]
|
||||
@ -74,4 +74,4 @@ class RouterExpertChain(Chain, BaseModel):
|
||||
ca = chain(action_input)
|
||||
chained_input.add(f"\n{self.router_chain.observation_prefix}")
|
||||
chained_input.add(ca, color=color_mapping[action])
|
||||
chained_input.add(f"\n{self.router_chain.router_prefix}")
|
||||
chained_input.add(f"\n{self.router_chain.router_prefix}")
|
@ -1,16 +1,12 @@
|
||||
"""Chain that does self ask with search."""
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Tuple
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.router import LLMRouterChain
|
||||
from langchain.chains.self_ask_with_search.prompt import PROMPT
|
||||
from langchain.chains.serpapi import SerpAPIChain
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.chains.router_expert import RouterExpertChain, ExpertConfig
|
||||
from langchain.smart_chains.router import LLMRouterChain
|
||||
from langchain.smart_chains.router_expert import ExpertConfig, RouterExpertChain
|
||||
from langchain.smart_chains.self_ask_with_search.prompt import PROMPT
|
||||
|
||||
|
||||
class SelfAskWithSearchRouter(LLMRouterChain):
|
||||
@ -32,7 +28,7 @@ class SelfAskWithSearchRouter(LLMRouterChain):
|
||||
finish_string = "So the final answer is: "
|
||||
if finish_string not in last_line:
|
||||
raise ValueError("We should probably never get here")
|
||||
return "Final Answer", text[len(finish_string):]
|
||||
return "Final Answer", text[len(finish_string) :]
|
||||
|
||||
if ":" not in last_line:
|
||||
after_colon = last_line
|
||||
@ -57,7 +53,7 @@ class SelfAskWithSearchRouter(LLMRouterChain):
|
||||
return ""
|
||||
|
||||
|
||||
class SelfAskWithSearchChain(Chain, BaseModel):
|
||||
class SelfAskWithSearchChain(RouterExpertChain):
|
||||
"""Chain that does self ask with search.
|
||||
|
||||
Example:
|
||||
@ -68,39 +64,16 @@ class SelfAskWithSearchChain(Chain, BaseModel):
|
||||
self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
"""LLM wrapper to use."""
|
||||
search_chain: SerpAPIChain
|
||||
"""Search chain to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
|
||||
"""Initialize with just an LLM and a search chain."""
|
||||
intermediate = "\nIntermediate answer:"
|
||||
router = SelfAskWithSearchRouter(self.llm, stops=[intermediate])
|
||||
expert_configs = [ExpertConfig(expert_name="Intermediate Answer", expert=self.search_chain.run)]
|
||||
chain = RouterExpertChain(router_chain=router, expert_configs=expert_configs, verbose=self.verbose, starter_string="\nAre follow up questions needed here:")
|
||||
output = chain.run(inputs[self.input_key])
|
||||
return {self.output_key: output}
|
||||
router = SelfAskWithSearchRouter(llm, stops=[intermediate])
|
||||
expert_configs = [
|
||||
ExpertConfig(expert_name="Intermediate Answer", expert=search_chain.run)
|
||||
]
|
||||
super().__init__(
|
||||
router_chain=router,
|
||||
expert_configs=expert_configs,
|
||||
starter_string="\nAre follow up questions needed here:",
|
||||
**kwargs
|
||||
)
|
@ -1,8 +1,8 @@
|
||||
"""Integration test for self ask with search."""
|
||||
|
||||
from langchain.chains.react.base import ReActChain
|
||||
from langchain.docstore.wikipedia import Wikipedia
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.smart_chains.react.base import ReActChain
|
||||
|
||||
|
||||
def test_react() -> None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Integration test for self ask with search."""
|
||||
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
from langchain.chains.serpapi import SerpAPIChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.smart_chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
|
||||
|
||||
def test_self_ask_with_search() -> None:
|
||||
|
1
tests/unit_tests/smart_chains/__init__.py
Normal file
1
tests/unit_tests/smart_chains/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Test smart chain functionality."""
|
@ -2,13 +2,13 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.mrkl.base import (
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.smart_chains.mrkl.base import (
|
||||
ChainConfig,
|
||||
MRKLRouterChain,
|
||||
get_action_and_input,
|
||||
)
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.smart_chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
@ -4,11 +4,11 @@ from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.react.base import ReActChain, ReActRouterChain
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.smart_chains.react.base import ReActChain, ReActRouterChain
|
||||
|
||||
_PAGE_CONTENT = """This is a page about LangChain.
|
||||
|
Loading…
Reference in New Issue
Block a user