Eugene Yurtsev
2023-06-02 12:12:17 -04:00
parent 39a2d2511d
commit 9c6accfa1a
2 changed files with 42 additions and 20 deletions

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
+import asyncio
 from typing import Optional, Any, Union, Literal, List, Dict, Mapping
 import itertools
@@ -35,8 +36,12 @@ class Research(Chain):
     to answer certain kinds of questions correctly.
     4. A reader that reads the documents and produces an answer.
-    This research chain only implements a single hop at the moment; i.e.,
-    it goes from the questions to a list of URLs to documents to compiling answers.
+    Limitations:
+    * This research chain only implements a single hop at the moment; i.e.,
+      it goes from the questions to a list of URLs to documents to compiling answers.
+    * The reader chain needs to match the task. For example, if using a QA refine
+      chain, a task of collecting a list of entries from a long document will
+      fail because the QA refine chain is not designed to handle such a task.
     The chain can be extended to continue crawling the documents in an attempt
     to discover relevant pages that were not surfaced by the search engine.
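
For orientation, here is a minimal runnable sketch of the single-hop flow this docstring describes. All names are hypothetical stand-ins; the real chain composes searcher, downloader, and reader sub-chains:

```python
from typing import Dict, List

# Hypothetical stand-ins for the searcher, downloader, and reader sub-chains.
def search(question: str) -> List[str]:
    """Return a list of candidate URLs for the question."""
    return ["https://example.com/a", "https://example.com/b"]

def download(urls: List[str]) -> List[str]:
    """Fetch each URL and return the raw page content."""
    return [f"<html>content of {url}</html>" for url in urls]

def read(doc: str, question: str) -> str:
    """Produce an answer to the question from a single document."""
    return f"answer derived from {doc[:24]}..."

def single_hop(question: str) -> Dict[str, List[str]]:
    # One hop only: question -> URLs -> documents -> answers.
    # No follow-up crawling of links discovered inside the pages.
    urls = search(question)
    docs = download(urls)
    return {"docs": [read(doc, question) for doc in docs]}

print(single_hop("What is the tallest building in the world?"))
```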
@@ -66,7 +71,7 @@ class Research(Chain):
     @property
     def output_keys(self) -> List[str]:
         """Return the output keys."""
-        return ["docs", "summary"]
+        return ["docs"]
 
     def _call(
         self,
@@ -75,14 +80,21 @@ class Research(Chain):
     ) -> Dict[str, Any]:
         """Run the chain synchronously."""
         question = inputs["question"]
-        search_results = self.searcher({"question": question})
+        search_results = self.searcher(
+            {"question": question},
+            callbacks=run_manager.get_child() if run_manager else None,
+        )
         urls = search_results["urls"]
         blobs = self.downloader.download(urls)
-        raise NotImplementedError()
-        # docs = []
-        # docs = self.reader({"blobs": blobs})
-        # summary = self.summarizer({"docs": docs})
-        # return {"docs": docs, "summary": summary}
+        parser = MarkdownifyHTMLParser()
+        docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
+        inputs = [{"doc": doc, "question": question} for doc in docs]
+        results = self.reader(
+            inputs, callbacks=run_manager.get_child() if run_manager else None
+        )
+        return {
+            "docs": [result["answer"] for result in results["inputs"]],
+        }
 
     async def _acall(
         self,
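
The rewritten `_call` uses a lazy fan-out: each downloaded blob may yield several documents, and `itertools.chain.from_iterable` flattens those generators without materializing every page up front. A self-contained sketch of the pattern, with a hypothetical stand-in for `MarkdownifyHTMLParser.lazy_parse`:

```python
import itertools
from typing import Iterator

def lazy_parse(blob: str) -> Iterator[str]:
    # Hypothetical stand-in for MarkdownifyHTMLParser.lazy_parse:
    # yields one "document" per paragraph of the blob.
    for paragraph in blob.split("\n\n"):
        yield paragraph.strip()

blobs = ["page one\n\npage two", "page three"]
docs = itertools.chain.from_iterable(lazy_parse(blob) for blob in blobs)
print(list(docs))  # ['page one', 'page two', 'page three']
```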
@@ -91,16 +103,21 @@ class Research(Chain):
     ) -> Dict[str, Any]:
         """Run the chain asynchronously."""
         question = inputs["question"]
-        search_results = await self.searcher.acall({"question": question})
+        search_results = await self.searcher.acall(
+            {"question": question},
+            callbacks=run_manager.get_child() if run_manager else None,
+        )
         urls = search_results["urls"]
         blobs = await self.downloader.adownload(urls)
         parser = MarkdownifyHTMLParser()
         docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
         inputs = [{"doc": doc, "question": question} for doc in docs]
-        results = await self.reader.acall(inputs)
+        results = await self.reader.acall(
+            inputs,
+            callbacks=run_manager.get_child() if run_manager else None,
+        )
         return {
-            "docs": [result["answer"] for result in results],
-            "summary": None,
+            "docs": [result["answer"] for result in results["results"]],
         }
 
     @classmethod
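
Both `_call` and `_acall` now forward callbacks with the same guard, `run_manager.get_child() if run_manager else None`, so sub-chain runs are traced as children of the outer run. A minimal sketch of the guard, using a stand-in for LangChain's run manager:

```python
from typing import Optional

class FakeRunManager:
    # Stand-in for LangChain's CallbackManagerForChainRun; get_child()
    # returns callback handlers scoped to a nested (child) run.
    def get_child(self) -> str:
        return "child-callbacks"

def child_callbacks(run_manager: Optional[FakeRunManager]) -> Optional[str]:
    # The guard used throughout this diff: sub-chains only receive child
    # callbacks when a run manager was actually provided.
    return run_manager.get_child() if run_manager else None

assert child_callbacks(None) is None
assert child_callbacks(FakeRunManager()) == "child-callbacks"
```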
@@ -109,7 +126,7 @@ class Research(Chain):
         *,
         query_generation_llm: BaseLanguageModel,
         link_selection_llm: BaseLanguageModel,
-        qa_chain: LLMChain,
+        underlying_reader_chain: LLMChain,
         top_k_per_search: int = -1,
         max_concurrency: int = 1,
         max_num_pages_per_doc: int = 5,
@@ -122,7 +139,7 @@ class Research(Chain):
         Args:
             query_generation_llm: The language model to use for query generation.
             link_selection_llm: The language model to use for link selection.
-            qa_chain: The chain to use to answer the question.
+            underlying_reader_chain: The chain to use to answer the question.
             top_k_per_search: The number of documents to return per search.
             max_concurrency: The maximum number of concurrent reads.
             max_num_pages_per_doc: The maximum number of pages to read per document.
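
A hedged usage sketch of the renamed parameter. `ChatOpenAI` and `load_qa_chain` are standard LangChain imports of this era; the import path for `Research` itself is not shown in the diff, so it is assumed here:

```python
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(temperature=0)
# The reader chain must match the task (see the Limitations note above);
# a "stuff" QA chain is used purely for illustration.
reader = load_qa_chain(llm, chain_type="stuff")

research = Research.from_llms(  # Research import path assumed
    query_generation_llm=llm,
    link_selection_llm=llm,
    underlying_reader_chain=reader,
    top_k_per_search=3,
    max_num_pages_per_doc=5,
)
result = research({"question": "What is the tallest building in the world?"})
print(result["docs"])
```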
@@ -165,8 +182,8 @@ class Research(Chain):
         )
         doc_reading_chain = DocReadingChain(
-            chain=qa_chain,
-            max_num_pages_per_doc=max_num_pages_per_doc,
+            chain=underlying_reader_chain,
+            max_num_docs=max_num_pages_per_doc,
             text_splitter=_text_splitter,
         )
 
         # Can read multiple documents in parallel

View File

@@ -26,8 +26,11 @@ class DocReadingChain(Chain):
     text_splitter: TextSplitter
     """The text splitter to use to split the document into smaller chunks."""
-    max_num_docs: int = -1
-    """The maximum number of documents to split the document into."""
+    max_num_docs: int
+    """The maximum number of documents to split the document into.
+
+    Use -1 to denote no limit on the number of pages to read.
+    """
 
     @property
     def input_keys(self) -> List[str]:
@@ -77,7 +80,7 @@ class DocReadingChain(Chain):
             _sub_docs = sub_docs[: self.max_num_docs]
         else:
             _sub_docs = sub_docs
 
         results = await self.chain.acall(
             {"input_documents": _sub_docs, "question": question},
             callbacks=run_manager.get_child(),
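
A minimal sketch of the `-1` sentinel semantics documented above (the helper name is hypothetical):

```python
from typing import List

def cap_docs(sub_docs: List[str], max_num_docs: int) -> List[str]:
    # -1 means "no limit"; any non-negative value truncates the split docs.
    # A bare sub_docs[:max_num_docs] would silently drop the last element
    # when max_num_docs == -1, hence the explicit guard.
    if max_num_docs >= 0:
        return sub_docs[:max_num_docs]
    return sub_docs

assert cap_docs(["a", "b", "c"], 2) == ["a", "b"]
assert cap_docs(["a", "b", "c"], -1) == ["a", "b", "c"]
```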
@@ -94,6 +97,8 @@ class ParallelApplyChain(Chain):
"""Utility chain to apply a given chain in parallel across input documents.
This chain needs to handle a limit on concurrency.
WARNING: Parallelization only implemented on the async path.
"""
chain: Chain
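
The diff states the concurrency requirement without showing the mechanism; one standard asyncio pattern for such a cap is a semaphore, sketched below with illustrative names:

```python
import asyncio
from typing import Any, Awaitable, List

async def gather_with_limit(
    coros: List[Awaitable[Any]], max_concurrency: int
) -> List[Any]:
    # Cap how many coroutines run at once; results come back in input
    # order, as with plain asyncio.gather.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(coro: Awaitable[Any]) -> Any:
        async with semaphore:
            return await coro

    return await asyncio.gather(*(run_one(c) for c in coros))

async def demo() -> None:
    async def job(i: int) -> int:
        await asyncio.sleep(0.01)
        return i

    print(await gather_with_limit([job(i) for i in range(5)], max_concurrency=2))

asyncio.run(demo())
```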