Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-29 15:59:48 +00:00)
move search to not be a chain (#226)
parent b19a73be26
commit ca2394028f
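
In short: SerpAPIChain, previously a Chain keyed on "search_query"/"search_result" and driven through _call, becomes SerpAPIWrapper, a plain pydantic object whose run(query: str) -> str performs the SerpAPI lookup directly. A minimal sketch of the new call site, assuming the google-search-results package is installed and SERPAPI_API_KEY is set (or serpapi_api_key is passed explicitly):

from langchain.serpapi import SerpAPIWrapper

# New-style usage after this commit: plain string in, plain string out,
# instead of a Chain call keyed on "search_query"/"search_result".
search = SerpAPIWrapper()
snippet = search.run("What was Obama's first name?")
print(snippet)
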
@@ -48,7 +48,7 @@
 "outputs": [],
 "source": [
 "from langchain.agents import ZeroShotAgent, Tool\n",
-"from langchain import OpenAI, SerpAPIChain, LLMChain"
+"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
 ]
 },
 {
@@ -58,7 +58,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "tools = [\n",
 "    Tool(\n",
 "        name = \"Search\",\n",
@@ -26,7 +26,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain import LLMMathChain, OpenAI, SerpAPIChain, SQLDatabase, SQLDatabaseChain\n",
+"from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
 "from langchain.agents import initialize_agent, Tool"
 ]
 },
@@ -38,7 +38,7 @@
 "outputs": [],
 "source": [
 "llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
 "db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n",
 "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",
@@ -45,11 +45,11 @@
 }
 ],
 "source": [
-"from langchain import OpenAI, SerpAPIChain\n",
+"from langchain import OpenAI, SerpAPIWrapper\n",
 "from langchain.agents import initialize_agent, Tool\n",
 "\n",
 "llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "tools = [\n",
 "    Tool(\n",
 "        name=\"Intermediate Answer\",\n",
@@ -29,7 +29,7 @@
 "source": [
 "from langchain.agents import ZeroShotAgent, Tool\n",
 "from langchain.chains.conversation.memory import ConversationBufferMemory\n",
-"from langchain import OpenAI, SerpAPIChain, LLMChain"
+"from langchain import OpenAI, SerpAPIWrapper, LLMChain"
 ]
 },
 {
@@ -39,7 +39,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "tools = [\n",
 "    Tool(\n",
 "        name = \"Search\",\n",
@@ -135,14 +135,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
+"from langchain import SelfAskWithSearchChain, SerpAPIWrapper\n",
 "\n",
 "open_ai_llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
 "\n",
 "cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
 ]
 },
@@ -77,9 +77,9 @@
 "outputs": [],
 "source": [
 "# Load the tool configs that are needed.\n",
-"from langchain import LLMMathChain, SerpAPIChain\n",
+"from langchain import LLMMathChain, SerpAPIWrapper\n",
 "llm = OpenAI(temperature=0)\n",
-"search = SerpAPIChain()\n",
+"search = SerpAPIWrapper()\n",
 "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n",
 "tools = [\n",
 "    Tool(\n",
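
The notebook cells above all follow the same pattern once the rename lands: construct the wrapper, then hand its run method to a Tool for an agent. A hedged sketch of that pattern (the description wording is illustrative, not taken from the diff):

from langchain import OpenAI, SerpAPIWrapper
from langchain.agents import Tool

llm = OpenAI(temperature=0)
search = SerpAPIWrapper()

# run() is a plain str -> str callable, so it slots straight into a Tool;
# the notebooks then pass such tools to ZeroShotAgent or initialize_agent.
tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for answering questions about current events",
    )
]
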
@@ -12,7 +12,6 @@ from langchain.chains import (
     LLMMathChain,
     PALChain,
     PythonChain,
-    SerpAPIChain,
     SQLDatabaseChain,
     VectorDBQA,
 )
@@ -24,6 +23,7 @@ from langchain.prompts import (
     Prompt,
     PromptTemplate,
 )
+from langchain.serpapi import SerpAPIWrapper
 from langchain.sql_database import SQLDatabase
 from langchain.vectorstores import FAISS, ElasticVectorSearch

@@ -32,7 +32,8 @@ __all__ = [
     "LLMMathChain",
     "PythonChain",
     "SelfAskWithSearchChain",
-    "SerpAPIChain",
+    "SerpAPIWrapper",
+    "SerpAPIWrapper",
     "Cohere",
     "OpenAI",
     "BasePromptTemplate",
@@ -131,10 +131,10 @@ class MRKLChain(ZeroShotAgent):
     Example:
         .. code-block:: python

-            from langchain import LLMMathChain, OpenAI, SerpAPIChain, MRKLChain
+            from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain
             from langchain.chains.mrkl.base import ChainConfig
             llm = OpenAI(temperature=0)
-            search = SerpAPIChain()
+            search = SerpAPIWrapper()
             llm_math_chain = LLMMathChain(llm=llm)
             chains = [
                 ChainConfig(
@@ -5,9 +5,9 @@ from langchain.agents.agent import Agent
 from langchain.agents.self_ask_with_search.prompt import PROMPT
 from langchain.agents.tools import Tool
 from langchain.chains.llm import LLMChain
-from langchain.chains.serpapi import SerpAPIChain
 from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.serpapi import SerpAPIWrapper


 class SelfAskWithSearchAgent(Agent):
@@ -73,12 +73,12 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
     Example:
         .. code-block:: python

-            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain
-            search_chain = SerpAPIChain()
+            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper
+            search_chain = SerpAPIWrapper()
             self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
     """

-    def __init__(self, llm: LLM, search_chain: SerpAPIChain, **kwargs: Any):
+    def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
         """Initialize with just an LLM and a search chain."""
         search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
         llm_chain = LLMChain(llm=llm, prompt=PROMPT)
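
The constructor change above is the whole integration point: SelfAskWithSearchChain now takes a SerpAPIWrapper and wraps its run method as the "Intermediate Answer" tool. A minimal sketch mirroring the updated docstring example (requires a SerpAPI key, as above):

from langchain import OpenAI, SelfAskWithSearchChain, SerpAPIWrapper

search_chain = SerpAPIWrapper()
# Internally this becomes Tool(name="Intermediate Answer", func=search_chain.run).
self_ask = SelfAskWithSearchChain(llm=OpenAI(temperature=0), search_chain=search_chain)
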
@@ -5,7 +5,6 @@ from langchain.chains.llm_math.base import LLMMathChain
 from langchain.chains.pal.base import PALChain
 from langchain.chains.python import PythonChain
 from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
-from langchain.chains.serpapi import SerpAPIChain
 from langchain.chains.sql_database.base import SQLDatabaseChain
 from langchain.chains.vector_db_qa.base import VectorDBQA

@@ -13,7 +12,6 @@ __all__ = [
     "LLMChain",
     "LLMMathChain",
     "PythonChain",
-    "SerpAPIChain",
     "SQLDatabaseChain",
     "VectorDBQA",
     "SequentialChain",
@@ -4,11 +4,10 @@ Heavily borrowed from https://github.com/ofirpress/self-ask
 """
 import os
 import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, Extra, root_validator

-from langchain.chains.base import Chain
 from langchain.utils import get_from_dict_or_env


@@ -26,8 +25,8 @@ class HiddenPrints:
         sys.stdout = self._original_stdout


-class SerpAPIChain(Chain, BaseModel):
-    """Chain that calls SerpAPI.
+class SerpAPIWrapper(BaseModel):
+    """Wrapper around SerpAPI.

     To use, you should have the ``google-search-results`` python package installed,
     and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
@@ -36,13 +35,11 @@ class SerpAPIChain(Chain, BaseModel):
     Example:
         .. code-block:: python

-            from langchain import SerpAPIChain
-            serpapi = SerpAPIChain()
+            from langchain import SerpAPIWrapper
+            serpapi = SerpAPIWrapper()
     """

     search_engine: Any  #: :meta private:
-    input_key: str = "search_query"  #: :meta private:
-    output_key: str = "search_result"  #: :meta private:

     serpapi_api_key: Optional[str] = None

@@ -51,22 +48,6 @@

         extra = Extra.forbid

-    @property
-    def input_keys(self) -> List[str]:
-        """Return the singular input key.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Return the singular output key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
@@ -85,11 +66,12 @@
         )
         return values

-    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def run(self, query: str) -> str:
         """Run query through SerpAPI and parse result."""
         params = {
             "api_key": self.serpapi_api_key,
             "engine": "google",
-            "q": inputs[self.input_key],
+            "q": query,
             "google_domain": "google.com",
             "gl": "us",
             "hl": "en",
@@ -112,4 +94,9 @@
             toret = res["organic_results"][0]["snippet"]
         else:
             toret = "No good search result found"
-        return {self.output_key: toret}
+        return toret
+
+
+# For backwards compatability
+
+SerpAPIWrapper = SerpAPIWrapper
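
The refactor above leaves the request parameters and result parsing untouched; run now returns the parsed snippet directly instead of a {output_key: ...} dict, falling back to "No good search result found". A short sketch exercising it the way the integration test below does (assumes a valid key; the explicit-key form is the alternative to the environment variable mentioned in the class docstring):

import os

from langchain.serpapi import SerpAPIWrapper

# Key resolution is handled by the root_validator: either SERPAPI_API_KEY in the
# environment or an explicit serpapi_api_key argument.
search = SerpAPIWrapper(serpapi_api_key=os.environ["SERPAPI_API_KEY"])
print(search.run("What was Obama's first name?"))
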
@@ -1,7 +1,7 @@
 """Integration test for self ask with search."""
 from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
-from langchain.chains.serpapi import SerpAPIChain
 from langchain.llms.openai import OpenAI
+from langchain.serpapi import SerpAPIWrapper


 def test_self_ask_with_search() -> None:
@@ -9,7 +9,7 @@ def test_self_ask_with_search() -> None:
     question = "What is the hometown of the reigning men's U.S. Open champion?"
     chain = SelfAskWithSearchChain(
         llm=OpenAI(temperature=0),
-        search_chain=SerpAPIChain(),
+        search_chain=SerpAPIWrapper(),
         input_key="q",
         output_key="a",
     )
@@ -1,9 +1,9 @@
 """Integration test for SerpAPI."""
-from langchain.chains.serpapi import SerpAPIChain
+from langchain.serpapi import SerpAPIWrapper


 def test_call() -> None:
     """Test that call gives the correct answer."""
-    chain = SerpAPIChain()
+    chain = SerpAPIWrapper()
     output = chain.run("What was Obama's first name?")
     assert output == "Barack Hussein Obama II"