This commit is contained in:
Eugene Yurtsev
2023-06-01 17:26:46 -04:00
parent f449d26083
commit b7dabe8f50
4 changed files with 122 additions and 23 deletions

View File

@@ -0,0 +1,92 @@
from __future__ import annotations

from typing import Any, Dict, Literal, Optional, Union

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
from langchain.chains.research.search import GenericSearcher
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
class Research(Chain):
"""A simple research chain."""
searcher: GenericSearcher
"""The searcher to use to search for documents."""
reader: Chain
"""The reader to use to read documents and produce an answer."""
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
async def _acall(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Run the research chain asynchronously.

    Args:
        inputs: Chain inputs; must contain a ``"question"`` key.
        run_manager: Optional callback manager for this run.

    Returns:
        The reader chain's output, or ``{"answer": None}`` when the search
        step yields no documents.
    """
    question = inputs["question"]
    # ``run_manager`` defaults to None; guard before asking for child callbacks
    # (the original unconditionally called ``run_manager.get_child()``, which
    # raised AttributeError when no manager was supplied).
    child_callbacks = run_manager.get_child() if run_manager is not None else None
    # Use ``acall`` so the search step participates in the async run and
    # propagates callbacks instead of being invoked synchronously.
    search_results = await self.searcher.acall(
        {"question": question}, callbacks=child_callbacks
    )
    documents = [result["document"] for result in search_results["results"]]
    if not documents:
        # Nothing to read; short-circuit with an empty answer.
        return {"answer": None}
    reader_inputs = {"doc": documents, "question": question}
    # BUG FIX: ``Chain.__call__`` is synchronous and returns a dict, so the
    # original ``await reader(...)`` raised ``TypeError: object dict can't be
    # used in 'await' expression``. ``acall`` returns an awaitable.
    return await self.reader.acall(reader_inputs, callbacks=child_callbacks)
@classmethod
def from_llms(
    cls,
    link_selection_llm: BaseLanguageModel,
    query_generation_llm: BaseLanguageModel,
    qa_chain: LLMChain,
    *,
    top_k_per_search: int = -1,
    max_concurrency: int = 1,
    max_num_pages_per_doc: int = 100,
    text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
) -> Research:
    """Helper to create a research chain from standard llm related components.

    Args:
        link_selection_llm: The language model to use for link selection.
        query_generation_llm: The language model to use for query generation.
        qa_chain: The chain to use to answer the question.
        top_k_per_search: The number of documents to return per search
            (-1 presumably means "no limit" — TODO confirm against searcher).
        max_concurrency: The maximum number of concurrent reads.
        max_num_pages_per_doc: The maximum number of pages to read per document.
        text_splitter: The text splitter to use to split the document into
            smaller chunks, or the literal ``"recursive"`` for a default
            ``RecursiveCharacterTextSplitter``.

    Returns:
        A research chain.

    Raises:
        ValueError: If ``text_splitter`` is an unrecognized string.
        TypeError: If ``text_splitter`` is neither a string nor a TextSplitter.
    """
    searcher = GenericSearcher.from_llms(
        link_selection_llm,
        query_generation_llm,
        top_k_per_search=top_k_per_search,
    )
    # Resolve the ``text_splitter`` spec into a concrete TextSplitter instance.
    if isinstance(text_splitter, str):
        if text_splitter == "recursive":
            _text_splitter = RecursiveCharacterTextSplitter()
        else:
            raise ValueError(f"Invalid text splitter: {text_splitter}")
    elif isinstance(text_splitter, TextSplitter):
        _text_splitter = text_splitter
    else:
        raise TypeError(f"Invalid text splitter: {type(text_splitter)}")
    reader = ParallelApplyChain(
        chain=DocReadingChain(
            qa_chain,
            max_num_pages_per_doc=max_num_pages_per_doc,
            # BUG FIX: the original passed the raw ``text_splitter`` argument
            # here, so the resolved ``_text_splitter`` was computed and then
            # discarded — with the default, DocReadingChain received the
            # string "recursive" instead of a TextSplitter instance.
            text_splitter=_text_splitter,
        ),
        max_concurrency=max_concurrency,
    )
    return cls(searcher=searcher, reader=reader)

View File

@@ -4,7 +4,6 @@ The main idea behind the crawling module is to identify additional links
that are worth exploring to find more documents that may be relevant for being
able to answer the question correctly.
"""
import json
import urllib.parse
from bs4 import PageElement, BeautifulSoup
from typing import List, Dict, Any, Tuple

View File

@@ -89,11 +89,13 @@ class DocReadingChain(Chain):
return {"document": summary_doc}
class ParallelApply(Chain):
"""Utility chain to apply a given chain in parallel across input documents."""
class ParallelApplyChain(Chain):
"""Utility chain to apply a given chain in parallel across input documents.
This chain needs to handle a limit on concurrency.
"""
chain: Chain
max_concurrency: int
@property
def input_keys(self) -> List[str]:

View File

@@ -1,9 +1,9 @@
"""Module for initiating a set of searches relevant for answering the question."""
from __future__ import annotations
import asyncio
import itertools
from bs4 import BeautifulSoup
from typing import Sequence, List, Mapping, Any, Optional, Union, Dict
from typing import Sequence, List, Mapping, Any, Optional, Dict
from langchain import PromptTemplate, LLMChain, serpapi
from langchain.base_language import BaseLanguageModel
@@ -11,9 +11,9 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.classification.multiselection import MultiSelectChain
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
def _extract_content_from_tag(html: str, tag: str) -> List[str]:
@@ -204,8 +204,8 @@ def make_query_generator(llm: BaseLanguageModel) -> LLMChain:
"""
return LLMChain(
llm=llm,
output_keys=["urls"],
promp=QUERY_GENERATION_PROMPT,
output_key="urls",
prompt=QUERY_GENERATION_PROMPT,
)
@@ -230,11 +230,12 @@ class GenericSearcher(Chain):
To extend implementation:
* Parameterize the search engine (to allow for non serp api search engines)
* Expose an abstract interface for query generator and link selection model
* Expose promoted answers from search engines as blobs that should be summarized
"""
query_generator: LLMChain
"""An LLM that is used to break down a complex question into a list of simpler queries."""
link_selection_model: LLMChain
link_selection_model: Chain
"""An LLM that is used to select the most relevant urls from the search results."""
top_k_per_search: int = -1
"""The number of top urls to select from each search."""
@@ -255,20 +256,21 @@ class GenericSearcher(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
question = inputs["question"]
response = self.query_generator.predict_and_parse(
queries = self.query_generator.predict_and_parse(
callbacks=run_manager.get_child(), question=question
)
queries = response["queries"]
results = _run_searches(queries, top_k=self.top_k_per_search)
deuped_results = _deduplicate_objects(results, "link")
records = [
{"link": result["link"], "title": result["title"]}
for result in deuped_results
]
response_ = self.link_selection_model.predict_and_parse(
response_ = self.link_selection_model(
{
"question": question,
"choices": records,
},
callbacks=run_manager.get_child(),
question=question,
choices=records,
)
return {"urls": [result["link"] for result in response_["selected"]]}
@@ -278,33 +280,37 @@ class GenericSearcher(Chain):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
question = inputs["question"]
response = await self.query_generator.apredict_and_parse(
queries = await self.query_generator.apredict_and_parse(
callbacks=run_manager.get_child(), question=question
)
queries = response["queries"]
results = await _arun_searches(queries)
results = _run_searches(queries, top_k=self.top_k_per_search)
deuped_results = _deduplicate_objects(results, "link")
records = [
{"link": result["link"], "title": result["title"]}
for result in deuped_results
]
response_ = self.link_selection_model.predict_and_parse(
response_ = await self.link_selection_model.acall(
{
"question": question,
"choices": records,
},
callbacks=run_manager.get_child(),
question=question,
choices=records,
)
return {"urls": [result["link"] for result in response_["selected"]]}
@classmethod
def from_llms(
    cls,
    link_selection_llm: BaseLanguageModel,
    query_generation_llm: BaseLanguageModel,
    *,
    top_k_per_search: int = -1,
) -> GenericSearcher:
    """Initialize the searcher from language models.

    Args:
        link_selection_llm: Model used to select the most relevant urls
            from search results (wrapped in a ``MultiSelectChain``).
        query_generation_llm: Model used to break the question down into
            simpler search queries.
        top_k_per_search: The number of top urls to keep from each search
            (-1 presumably means "keep all" — TODO confirm against _call).

    Returns:
        A configured ``GenericSearcher``.
    """
    link_selection_model = MultiSelectChain.from_default(llm=link_selection_llm)
    query_generation_model = make_query_generator(query_generation_llm)
    return cls(
        link_selection_model=link_selection_model,
        query_generator=query_generation_model,
        top_k_per_search=top_k_per_search,
    )