mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
@@ -3,13 +3,17 @@ from __future__ import annotations
import asyncio
import itertools
from bs4 import BeautifulSoup
from typing import Sequence, List, Mapping, Any, Optional
from typing import Sequence, List, Mapping, Any, Optional, Union, Dict

from langchain import PromptTemplate, LLMChain, serpapi
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.chains.research.typedefs import AbstractQueryGenerator, AbstractSearcher
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain


def _extract_content_from_tag(html: str, tag: str) -> List[str]:
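# The body of _extract_content_from_tag is not included in this hunk. A minimal
# sketch of what such a helper could look like, using the BeautifulSoup import
# above (an assumption for illustration, not the committed implementation):
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
    """Extract the text content of every matching tag from the given HTML."""
    soup = BeautifulSoup(html, "html.parser")
    return [element.get_text() for element in soup.find_all(tag)]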
@@ -47,7 +51,6 @@ class QueryExtractor(BaseOutputParser[List[str]]):
        return _extract_content_from_tag(text, "query")


# TODO(Eugene): add a version that works for chat models as well w/ human system message?
QUERY_GENERATION_PROMPT = PromptTemplate.from_template(
    """\
Suggest a few different search queries that could be used to identify web-pages \
@@ -121,11 +124,15 @@ def _deduplicate_objects(
    return deduped

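# The hunk above only shows the tail of _deduplicate_objects. A plausible sketch of
# de-duplication on a key, for illustration only (an assumption, not the committed
# code; the name is suffixed with _sketch to mark it as hypothetical):
def _deduplicate_objects_sketch(
    objects: Sequence[Mapping[str, Any]], key: str
) -> List[Mapping[str, Any]]:
    """Keep the first occurrence of each object, de-duplicating on the given key."""
    seen = set()
    deduped = []
    for obj in objects:
        value = obj[key]
        if value not in seen:
            seen.add(value)
            deduped.append(obj)
    return deduped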
def _run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
    """Run the given queries and return the unique results.
def _run_searches(queries: Sequence[str], top_k: int = -1) -> List[Mapping[str, Any]]:
    """Run the given queries and return all the search results.

    This function can return duplicated results; de-duplication can take place later
    and take into account the frequency of appearance.

    Args:
        queries: a list of queries to run
        top_k: the number of results to return per query; if -1, return all results

    Returns:
        a list of search results
@@ -134,26 +141,50 @@ def _run_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
    results = []
    for query in queries:
        result = wrapper.results(query)
        organic_results = result["organic_results"]
        if top_k <= 0:
            organic_results = result["organic_results"]
        else:
            organic_results = result["organic_results"][:top_k]
        results.extend(organic_results)
    return results


async def _arun_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
    """Run the given queries and return the unique results."""
async def _arun_searches(
    queries: Sequence[str], top_k: int = -1
) -> List[Mapping[str, Any]]:
    """Run the given queries and return all the search results.

    This function can return duplicated results; de-duplication can take place later
    and take into account the frequency of appearance.

    Args:
        queries: a list of queries to run
        top_k: the number of results to return per query; if -1, return all results

    Returns:
        a list of search results
    """
    wrapper = serpapi.SerpAPIWrapper()
    tasks = [wrapper.results(query) for query in queries]
    results = await asyncio.gather(*tasks)

    return list(
        itertools.chain.from_iterable(result["organic_results"] for result in results)
    )
    finalized_results = []

    for result in results:
        if top_k <= 0:
            organic_results = result["organic_results"]
        else:
            organic_results = result["organic_results"][:top_k]

        finalized_results.extend(organic_results)

    return finalized_results


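# Hedged usage sketch (illustration only, not part of the commit). Assumes
# SERPAPI_API_KEY is set in the environment so serpapi.SerpAPIWrapper can authenticate.
example_queries = ["Harrison Chase", "Harrison Chase Langchain"]
# Synchronous variant: keep at most 3 organic results per query.
example_results = _run_searches(example_queries, top_k=3)
# Async variant returns results of the same shape.
example_results_async = asyncio.run(_arun_searches(example_queries, top_k=3))
# De-duplicate on the "link" field before any downstream selection.
example_links = [r["link"] for r in _deduplicate_objects(example_results, "link")]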
# PUBLIC API


class QueryGenerator(AbstractQueryGenerator):
def make_query_generator(llm: BaseLanguageModel) -> LLMChain:
    """Use an LLM to break down a complex query into a list of simpler queries.

    The simpler queries are used to run searches against a search engine in a goal
@@ -162,82 +193,118 @@ class QueryGenerator(AbstractQueryGenerator):

    Query:

        Does Harrison Chase of Langchain like to eat pizza or play squash?

    May be broken down into a list of simpler queries like:

        - Harrison Chase
        - Harrison Chase Langchain
        - Harrison Chase Langchain pizza
        - Harrison Chase Langchain squash
    """
    return LLMChain(
        llm=llm,
        output_keys=["urls"],
        prompt=QUERY_GENERATION_PROMPT,
    )


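# Hedged usage sketch (illustration only, not part of the commit). Assumes an OpenAI
# completion model is available; the parsed output depends on the prompt's output parser.
from langchain.llms import OpenAI

example_query_generator = make_query_generator(OpenAI(temperature=0))
example_generated_queries = example_query_generator.predict_and_parse(
    question="Does Harrison Chase of Langchain like to eat pizza or play squash?"
)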
class GenericSearcher(Chain):
    """A chain that takes a complex question and identifies a list of relevant urls.

    The chain works by:

    1. Breaking a complex question into a series of simpler queries using an LLM.
    2. Running the queries against a search engine.
    3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf or other models).

    This chain is not meant to be used for questions requiring multiple hops to answer.

    For example, the age of Leonardo DiCaprio's girlfriend is a multi-hop question.
    This kind of question requires a slightly different approach.

    This chain is meant to handle questions for which one wants
    to collect information from multiple sources.

    To extend the implementation:

    * Parameterize the search engine (to allow for non-SerpAPI search engines)
    * Expose an abstract interface for the query generator and link selection model
    """

    def __init__(self, llm: BaseLanguageModel) -> None:
        """Initialize the query generator."""
        self.chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=llm)
    query_generator: LLMChain
    """An LLM chain that is used to break down a complex question into a list of simpler queries."""
    link_selection_model: LLMChain
    """An LLM chain that is used to select the most relevant urls from the search results."""
    top_k_per_search: int = -1
    """The number of top urls to select from each search."""

    def generate_queries(self, question: str) -> List[str]:
        """Generate queries using a Chain."""
        queries = self.chain.predict_and_parse(question=question)
        return queries
    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["question"]

    async def agenerate_queries(self, question: str) -> List[str]:
        """Generate queries using a Chain."""
        queries = await self.chain.apredict_and_parse(question=question)
        return queries
    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["urls"]


class GenericSearcher(AbstractSearcher):
    """A generic searcher implementation.

    Generic searcher is parameterized by the search engine to allow for
    extending it to other search engines in the future.
    """

    def __init__(
    def _call(
        self,
        *,
        search_engine: str = "serp",
        link_selection_model: Optional[MultiSelectChain] = None,
    ) -> None:
        """Initialize the searcher.

        Args:
            search_engine: the search engine to use. Placeholder for future work.
            link_selection_model: the link selection model to use.
        """
        if search_engine != "serp":
            raise NotImplementedError("Only serp is supported at the moment.")
        self.link_selection_model = link_selection_model

    async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
        """Run searches for the given sequence of queries.

        Args:
            queries: a list of queries to run

        Returns:
            a list of search results
        """
        results = await _arun_searches(queries)
        deduplicated_results = _deduplicate_objects(results, "link")
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        question = inputs["question"]
        response = self.query_generator.predict_and_parse(
            callbacks=run_manager.get_child(), question=question
        )
        queries = response["queries"]
        results = _run_searches(queries, top_k=self.top_k_per_search)
        deduped_results = _deduplicate_objects(results, "link")
        records = [
            {
                "title": result["title"],
                "snippet": result["snippet"],
                "link": result["link"],
            }
            for result in deduplicated_results
            {"link": result["link"], "title": result["title"]}
            for result in deduped_results
        ]
        response_ = self.link_selection_model.predict_and_parse(
            callbacks=run_manager.get_child(),
            question=question,
            choices=records,
        )
        return {"urls": [result["link"] for result in response_["selected"]]}

        return records
    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        question = inputs["question"]
        response = await self.query_generator.apredict_and_parse(
            callbacks=run_manager.get_child(), question=question
        )
        queries = response["queries"]
        results = await _arun_searches(queries)
        deduped_results = _deduplicate_objects(results, "link")
        records = [
            {"link": result["link"], "title": result["title"]}
            for result in deduped_results
        ]
        response_ = self.link_selection_model.predict_and_parse(
            callbacks=run_manager.get_child(),
            question=question,
            choices=records,
        )
        return {"urls": [result["link"] for result in response_["selected"]]}

    @classmethod
    def from_llm(
        cls, link_selection_llm: BaseLanguageModel, *, search_engine: str = "serp"
        cls,
        link_selection_llm: BaseLanguageModel,
        query_generation_llm: BaseLanguageModel,
    ) -> GenericSearcher:
        """Initialize the searcher from a language model."""
        link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
        query_generation_model = make_query_generator(llm=query_generation_llm)
        return cls(
            search_engine=search_engine, link_selection_model=link_selection_model
            link_selection_model=link_selection_model,
            query_generator=query_generation_model,
        )

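# Hedged usage sketch (illustration only, not part of the commit). Assumes OpenAI models
# and a SerpAPI key in the environment; the chain returns a dict with a "urls" key.
from langchain.llms import OpenAI

example_searcher = GenericSearcher.from_llm(
    link_selection_llm=OpenAI(temperature=0),
    query_generation_llm=OpenAI(temperature=0),
)
example_output = example_searcher(
    {"question": "Does Harrison Chase of Langchain like to eat pizza or play squash?"}
)
example_urls = example_output["urls"]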
@@ -5,52 +5,52 @@ import abc
from langchain.callbacks.manager import Callbacks
from langchain.document_loaders.blob_loaders import Blob


class AbstractQueryGenerator(abc.ABC):
    """Abstract class for generating queries."""

    @abc.abstractmethod
    def generate_queries(self, question: str, callbacks: Callbacks = None) -> List[str]:
        """Generate queries for the given question."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def agenerate_queries(
        self, question: str, callbacks: Callbacks = None
    ) -> List[str]:
        """Generate queries for the given question."""
        raise NotImplementedError()


class AbstractSearcher(abc.ABC):
    """Abstract class for running searches."""

    def search(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
        """Run a search for the given query.

        Args:
            queries: the query to run the search for.

        Returns:
            a list of search results.
        """
        raise NotImplementedError()

    async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
        """Run a search for the given query.

        Args:
            queries: the query to run the search for.

        Returns:
            a list of search results.
        """
        raise NotImplementedError()


class BlobCrawler(abc.ABC):
    """Crawl a blob and identify links to related content."""

    @abc.abstractmethod
    def crawl(self, blob: Blob, query: str, callbacks: Callbacks = None) -> List[str]:
        """Explore the blob and identify links to related content that is relevant to the query."""
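# A minimal sketch of a concrete BlobCrawler, assuming the blob holds HTML content
# (illustration only, not part of the commit; the class name is hypothetical).
from bs4 import BeautifulSoup


class HtmlLinkCrawler(BlobCrawler):
    """Illustrative crawler that returns every anchor href found in an HTML blob."""

    def crawl(self, blob: Blob, query: str, callbacks: Callbacks = None) -> List[str]:
        soup = BeautifulSoup(blob.as_string(), "html.parser")
        return [anchor["href"] for anchor in soup.find_all("a", href=True)]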
#
# class AbstractQueryGenerator(abc.ABC):
#     """Abstract class for generating queries."""
#
#     @abc.abstractmethod
#     def generate_queries(self, question: str, callbacks: Callbacks = None) -> List[str]:
#         """Generate queries for the given question."""
#         raise NotImplementedError()
#
#     @abc.abstractmethod
#     async def agenerate_queries(
#         self, question: str, callbacks: Callbacks = None
#     ) -> List[str]:
#         """Generate queries for the given question."""
#         raise NotImplementedError()
#
#
# class AbstractSearcher(abc.ABC):
#     """Abstract class for running searches."""
#
#     def search(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
#         """Run a search for the given query.
#
#         Args:
#             queries: the query to run the search for.
#
#         Returns:
#             a list of search results.
#         """
#         raise NotImplementedError()
#
#     async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
#         """Run a search for the given query.
#
#         Args:
#             queries: the query to run the search for.
#
#         Returns:
#             a list of search results.
#         """
#         raise NotImplementedError()
#
#
# class BlobCrawler(abc.ABC):
#     """Crawl a blob and identify links to related content."""
#
#     @abc.abstractmethod
#     def crawl(self, blob: Blob, query: str, callbacks: Callbacks = None) -> List[str]:
#         """Explore the blob and identify links to related content that is relevant to the query."""