This commit is contained in:
Eugene Yurtsev
2023-05-31 09:52:21 -04:00
parent d381c4fad8
commit 6a558c72a2

View File

@@ -1,12 +1,13 @@
"""Module for initiating a set of searches relevant for answering the question."""
from __future__ import annotations
import asyncio
import itertools
from bs4 import BeautifulSoup
from typing import Sequence, List, Mapping, Any
from typing import Sequence, List, Mapping, Any, Optional
from langchain import PromptTemplate, LLMChain, serpapi
from langchain.chains import LLMBashChain
from langchain.base_language import BaseLanguageModel
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.chains.research.typedefs import AbstractQueryGenerator, AbstractSearcher
from langchain.schema import BaseOutputParser
@@ -152,17 +153,37 @@ async def _arun_searches(queries: Sequence[str]) -> List[Mapping[str, Any]]:
# PUBLIC API
class LLMBasedQueryGenerator(AbstractQueryGenerator):
    """Generic implementation of the query generator interface."""
class QueryGenerator(AbstractQueryGenerator):
    """Use an LLM to break down a complex query into a list of simpler queries.

    The simpler queries are run against a search engine with the goal of
    increasing recall, i.e. retrieving as many relevant results as possible.

    Query:

        Does Harrison Chase of Langchain like to eat pizza or play squash?

    May be broken down into a list of simpler queries like:

        - Harrison Chase
        - Harrison Chase Langchain
        - Harrison Chase Langchain pizza
        - Harrison Chase Langchain squash
    """

    def __init__(self, llm: BaseLanguageModel) -> None:
        """Initialize the query generator.

        Args:
            llm: the language model used to generate the queries.
        """
        self.llm = llm
        # Build the chain once so the sync and async entry points share it
        # instead of re-creating it on every call.
        self.chain = LLMChain(prompt=QUERY_GENERATION_PROMPT, llm=llm)

    def generate_queries(self, question: str) -> List[str]:
        """Generate queries for the given question using the chain.

        Args:
            question: the question to break down into simpler queries.

        Returns:
            The parsed output of the chain for ``question``.
        """
        # Run the prediction exactly once on the shared chain; a second
        # call would issue a duplicate LLM request and discard its result.
        return self.chain.predict_and_parse(question=question)

    async def agenerate_queries(self, question: str) -> List[str]:
        """Asynchronously generate queries for the given question.

        Args:
            question: the question to break down into simpler queries.

        Returns:
            The parsed output of the chain for ``question``.
        """
        return await self.chain.apredict_and_parse(question=question)
@@ -175,21 +196,22 @@ class GenericSearcher(AbstractSearcher):
def __init__(
    self,
    link_selection_llm: Optional[BaseLanguageModel] = None,
    *,
    search_engine: str = "serp",
    link_selection_model: Optional[MultiSelectChain] = None,
) -> None:
    """Initialize the searcher.

    Args:
        link_selection_llm: the language model to use for link selection.
            Optional, so that callers which only supply a ready-made
            ``link_selection_model`` can omit it.
        search_engine: the search engine to use. Placeholder for future work.
        link_selection_model: the link selection model to use.

    Raises:
        NotImplementedError: if ``search_engine`` is anything but "serp".
    """
    if search_engine != "serp":
        raise NotImplementedError("Only serp is supported at the moment.")
    self.link_selection_llm = link_selection_llm
    self.link_selection_model = link_selection_model
def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
"""Run a search for the given query.
async def asearch(self, queries: Sequence[str]) -> List[Mapping[str, Any]]:
"""Run searchers for given sequence of queries.
Args:
queries: a list of queries to run
@@ -209,3 +231,13 @@ class GenericSearcher(AbstractSearcher):
]
return records
@classmethod
def from_llm(
    cls, link_selection_llm: BaseLanguageModel, *, search_engine: str = "serp"
) -> GenericSearcher:
    """Initialize the searcher from a language model.

    Args:
        link_selection_llm: the language model used to build the default
            link selection chain.
        search_engine: the search engine to use; forwarded to the constructor.

    Returns:
        A searcher whose link selection model is the default
        ``MultiSelectChain`` built on ``link_selection_llm``.
    """
    link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
    # The constructor takes the LLM as its leading argument and stores it,
    # so forward it along with the chain built from it.
    return cls(
        link_selection_llm,
        search_engine=search_engine,
        link_selection_model=link_selection_model,
    )