mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-03 10:12:33 +00:00)

move search to not be a chain (#226)

commit ca2394028f (parent b19a73be26)
@@ -48,7 +48,7 @@
 "outputs": [],
 "source": [
 "from langchain.agents import ZeroShotAgent, Tool\n",
-"from langchain import OpenAI, SerpAPIChain, LLMChain"
+"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
 ]
 },
 {
@@ -58,7 +58,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "tools = [\n",
 " Tool(\n",
 " name = \"Search\",\n",
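
For orientation, the change in these notebook cells amounts to swapping the chain for the wrapper when the search tool is built. A minimal sketch of the updated setup, assuming a SERPAPI_API_KEY is set in the environment (the tool description text is illustrative, not copied from the notebook):

    from langchain.agents import ZeroShotAgent, Tool
    from langchain import OpenAI, SerpAPIWrapper, LLMChain

    # SerpAPIWrapper is a plain utility now; its run() method is what gets handed to the Tool.
    search = SerpAPIWrapper()
    tools = [
        Tool(
            name="Search",
            func=search.run,
            description="useful for answering questions about current events",
        )
    ]
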
@@ -26,7 +26,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain import LLMMathChain, OpenAI, SerpAPIChain, SQLDatabase, SQLDatabaseChain\n",
+"from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
 "from langchain.agents import initialize_agent, Tool"
 ]
 },
@@ -38,7 +38,7 @@
 "outputs": [],
 "source": [
 "llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
 "db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
 "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",
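
The same substitution applies when several tools are combined; roughly, the updated cell builds its tool list as in the sketch below (the Tool names and descriptions here are placeholders, not taken from the notebook):

    from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain
    from langchain.agents import initialize_agent, Tool

    llm = OpenAI(temperature=0)
    search = SerpAPIWrapper()  # previously SerpAPIChain()
    llm_math_chain = LLMMathChain(llm=llm, verbose=True)
    db = SQLDatabase.from_uri("sqlite:///../../../notebooks/Chinook.db")
    db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

    tools = [
        Tool(name="Search", func=search.run, description="answer questions about current events"),
        Tool(name="Calculator", func=llm_math_chain.run, description="answer math questions"),
        Tool(name="FooBar DB", func=db_chain.run, description="query the FooBar database"),
    ]
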
@@ -45,11 +45,11 @@
 }
 ],
 "source": [
-"from langchain import OpenAI, SerpAPIChain\n",
+"from langchain import OpenAI, SerpAPIWrapper\n",
 "from langchain.agents import initialize_agent, Tool\n",
 "\n",
 "llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "tools = [\n",
 " Tool(\n",
 " name=\"Intermediate Answer\",\n",
@@ -29,7 +29,7 @@
 "source": [
 "from langchain.agents import ZeroShotAgent, Tool\n",
 "from langchain.chains.conversation.memory import ConversationBufferMemory\n",
-"from langchain import OpenAI, SerpAPIChain, LLMChain"
+"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
 ]
 },
 {
@@ -39,7 +39,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "tools = [\n",
 " Tool(\n",
 " name = \"Search\",\n",
@@ -135,14 +135,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
+"from langchain import SelfAskWithSearchChain, SerpAPIWrapper\n",
 "\n",
 "open_ai_llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
 "\n",
 "cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
 ]
 },
@@ -77,9 +77,9 @@
 "outputs": [],
 "source": [
 "# Load the tool configs that are needed.\n",
-"from langchain import LLMMathChain, SerpAPIChain\n",
+"from langchain import LLMMathChain, SerpAPIWrapper\n",
 "llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
 "tools = [\n",
 " Tool(\n",
@@ -12,7 +12,6 @@ from langchain.chains import (
     LLMMathChain,
     PALChain,
     PythonChain,
-    SerpAPIChain,
     SQLDatabaseChain,
     VectorDBQA,
 )
@@ -24,6 +23,7 @@ from langchain.prompts import (
     Prompt,
     PromptTemplate,
 )
+from langchain.serpapi import SerpAPIWrapper
 from langchain.sql_database import SQLDatabase
 from langchain.vectorstores import FAISS, ElasticVectorSearch
 
@@ -32,7 +32,8 @@ __all__ = [
     "LLMMathChain",
     "PythonChain",
     "SelfAskWithSearchChain",
-    "SerpAPIChain",
+    "SerpAPIWrapper",
+    "SerpAPIChain",
     "Cohere",
     "OpenAI",
     "BasePromptTemplate",
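
With the re-export above (and the backwards-compatible alias added at the end of langchain/serpapi.py later in this diff), the wrapper should be importable along any of these paths; a small sketch, assuming the alias is kept:

    # New canonical location of the utility.
    from langchain.serpapi import SerpAPIWrapper

    # Also re-exported from the package root via __all__.
    from langchain import SerpAPIWrapper

    # The old name survives as an alias defined in langchain.serpapi.
    from langchain.serpapi import SerpAPIChain
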
@@ -131,10 +131,10 @@ class MRKLChain(ZeroShotAgent):
     Example:
         .. code-block:: python
 
-            from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
+            from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain
             from langchain.chains.mrkl.base import ChainConfig
             llm = OpenAI(temperature=0)
-            search = SerpAPIChain()
+            search = SerpAPIWrapper()
             llm_math_chain = LLMMathChain(llm=llm)
             chains = [
                 ChainConfig(
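
The docstring example above continues with ChainConfig entries; a rough completion of it, under the assumption that ChainConfig takes action_name, action and action_description fields and that MRKLChain exposes a from_chains constructor (neither is shown in this hunk):

    from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain
    from langchain.chains.mrkl.base import ChainConfig

    llm = OpenAI(temperature=0)
    search = SerpAPIWrapper()
    llm_math_chain = LLMMathChain(llm=llm)
    chains = [
        ChainConfig(
            action_name="Search",
            action=search.run,  # the wrapper's run() replaces the old chain call
            action_description="useful for answering questions about current events",
        ),
        ChainConfig(
            action_name="Calculator",
            action=llm_math_chain.run,
            action_description="useful for doing math",
        ),
    ]
    mrkl = MRKLChain.from_chains(llm, chains)
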
@@ -5,9 +5,9 @@ from langchain.agents.agent import Agent
 from langchain.agents.self_ask_with_search.prompt import PROMPT
 from langchain.agents.tools import Tool
 from langchain.chains.llm import LLMChain
-from langchain.chains.serpapi import SerpAPIChain
 from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.serpapi import SerpAPIWrapper
 
 
 class SelfAskWithSearchAgent(Agent):
@@ -73,12 +73,12 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
     Example:
         .. code-block:: python
 
-            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain
-            search_chain = SerpAPIChain()
+            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper
+            search_chain = SerpAPIWrapper()
             self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
     """
 
-    def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
+    def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
         """Initialize with just an LLM and a search chain."""
         search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
         llm_chain = LLMChain(llm=llm, prompt=PROMPT)
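
Per the updated docstring and constructor, the self-ask agent now takes the wrapper directly and wires its run() into the "Intermediate Answer" tool itself. A minimal usage sketch (the question string is only an example, and the final run() call assumes the agent keeps the standard single-input chain interface):

    from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper

    search_chain = SerpAPIWrapper()
    self_ask = SelfAskWithSearchChain(llm=OpenAI(temperature=0), search_chain=search_chain)

    self_ask.run("What is the hometown of the reigning men's U.S. Open champion?")
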
@@ -5,7 +5,6 @@ from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.pal.base import PALChain
 from langchain.chains.python import PythonChain
 from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
-from langchain.chains.serpapi import SerpAPIChain
 from langchain.chains.sql_database.base import SQLDatabaseChain
 from langchain.chains.vector_db_qa.base import VectorDBQA
 
@@ -13,7 +12,6 @@ __all__ = [
     "LLMChain",
     "LLMMathChain",
     "PythonChain",
-    "SerpAPIChain",
     "SQLDatabaseChain",
     "VectorDBQA",
     "SequentialChain",
@@ -4,11 +4,10 @@ Heavily borrowed from https://github.com/ofirpress/self-ask
 """
 import os
 import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
-from langchain.chains.base import Chain
 from langchain.utils import get_from_dict_or_env
 
 
@@ -26,8 +25,8 @@ class HiddenPrints:
         sys.stdout = self._original_stdout
 
 
-class SerpAPIChain(Chain, BaseModel):
-    """Chain that calls SerpAPI.
+class SerpAPIWrapper(BaseModel):
+    """Wrapper around SerpAPI.
 
     To use, you should have the ``google-search-results`` python package installed,
     and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
@@ -36,13 +35,11 @@ class SerpAPIChain(Chain, BaseModel):
     Example:
         .. code-block:: python
 
-            from langchain import SerpAPIChain
-            serpapi = SerpAPIChain()
+            from langchain import SerpAPIWrapper
+            serpapi = SerpAPIWrapper()
     """
 
     search_engine: Any  #: :meta private:
-    input_key: str = "search_query"  #: :meta private:
-    output_key: str = "search_result"  #: :meta private:
 
     serpapi_api_key: Optional[str] = None
 
@@ -51,22 +48,6 @@ class SerpAPIChain(Chain, BaseModel):
 
         extra = Extra.forbid
 
-    @property
-    def input_keys(self) -> List[str]:
-        """Return the singular input key.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Return the singular output key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
@@ -85,11 +66,12 @@ class SerpAPIChain(Chain, BaseModel):
             )
         return values
 
-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def run(self, query: str) -> str:
+        """Run query through SerpAPI and parse result."""
         params = {
             "api_key": self.serpapi_api_key,
             "engine": "google",
-            "q": inputs[self.input_key],
+            "q": query,
             "google_domain": "google.com",
             "gl": "us",
             "hl": "en",
@@ -112,4 +94,9 @@ class SerpAPIChain(Chain, BaseModel):
             toret = res["organic_results"][0]["snippet"]
         else:
             toret = "No good search result found"
-        return {self.output_key: toret}
+        return toret
+
+
+# For backwards compatability
+SerpAPIChain = SerpAPIWrapper
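
The net effect on the utility itself: it is no longer invoked as a chain keyed by search_query/search_result, but called directly through run(). A before/after sketch (the old-style call reflects the Chain interface removed above):

    from langchain.serpapi import SerpAPIWrapper

    serpapi = SerpAPIWrapper()

    # Old, chain-style invocation (removed by this commit):
    #     result = serpapi({"search_query": "What was Obama's first name?"})["search_result"]

    # New, direct invocation:
    result = serpapi.run("What was Obama's first name?")
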
@@ -1,7 +1,7 @@
 """Integration test for self ask with search."""
 from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
-from langchain.chains.serpapi import SerpAPIChain
 from langchain.llms.openai import OpenAI
+from langchain.serpapi import SerpAPIWrapper
 
 
 def test_self_ask_with_search() -> None:
@@ -9,7 +9,7 @@ def test_self_ask_with_search() -> None:
     question = "What is the hometown of the reigning men's U.S. Open champion?"
     chain = SelfAskWithSearchChain(
         llm=OpenAI(temperature=0),
-        search_chain=SerpAPIChain(),
+        search_chain=SerpAPIWrapper(),
         input_key="q",
         output_key="a",
     )
@@ -1,9 +1,9 @@
 """Integration test for SerpAPI."""
-from langchain.chains.serpapi import SerpAPIChain
+from langchain.serpapi import SerpAPIWrapper
 
 
 def test_call() -> None:
     """Test that call gives the correct answer."""
-    chain = SerpAPIChain()
+    chain = SerpAPIWrapper()
     output = chain.run("What was Obama's first name?")
     assert output == "Barack Hussein Obama II"