Eugene Yurtsev
2023-06-01 23:14:11 -04:00
parent f24e521015
commit 679eb9f14f
2 changed files with 44 additions and 25 deletions


@@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain.chains.research.fetch import DownloadHandler
+from langchain.chains.research.fetch import DownloadHandler, AutoDownloadHandler
 from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
 from langchain.chains.research.search import GenericSearcher
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@@ -106,34 +106,34 @@ class Research(Chain):
     @classmethod
     def from_llms(
         cls,
-        link_selection_llm: BaseLanguageModel,
-        query_generation_llm: BaseLanguageModel,
-        qa_chain: LLMChain,
         *,
+        query_generation_llm: BaseLanguageModel,
+        link_selection_llm: BaseLanguageModel,
+        qa_chain: LLMChain,
         top_k_per_search: int = -1,
         max_concurrency: int = 1,
         max_num_pages_per_doc: int = 100,
         text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
+        download_handler: Union[DownloadHandler, Literal["auto"]] = "auto",
     ) -> Research:
         """Helper to create a research chain from standard llm related components.

         Args:
-            link_selection_llm: The language model to use for link selection.
             query_generation_llm: The language model to use for query generation.
+            link_selection_llm: The language model to use for link selection.
             qa_chain: The chain to use to answer the question.
             top_k_per_search: The number of documents to return per search.
             max_concurrency: The maximum number of concurrent reads.
             max_num_pages_per_doc: The maximum number of pages to read per document.
             text_splitter: The text splitter to use to split the document into smaller chunks.
+            download_handler: The download handler to use to download the documents.
+                Provide either a download handler or the name of a
+                download handler.
+                - "auto" swaps between using requests and playwright

         Returns:
             A research chain.
         """
-        searcher = GenericSearcher.from_llms(
-            link_selection_llm,
-            query_generation_llm,
-            top_k_per_search=top_k_per_search,
-        )
         if isinstance(text_splitter, str):
             if text_splitter == "recursive":
                 _text_splitter = RecursiveCharacterTextSplitter()
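
The reorder plus the leading *, makes every component keyword-only, so call sites must name which model does what. A minimal sketch of the new call shape, with ChatOpenAI standing in for both language models and a load_qa_chain-built QA chain (an assumption: the annotation says LLMChain, but the reader consumes an "output_text" key, which is what the question-answering loaders of this langchain version produce):

    from langchain.chains.question_answering import load_qa_chain
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(temperature=0)  # placeholder; any BaseLanguageModel should work

    research = Research.from_llms(
        # keyword-only because of the leading *, in the new signature
        query_generation_llm=llm,
        link_selection_llm=llm,
        qa_chain=load_qa_chain(llm, chain_type="map_reduce"),
        top_k_per_search=3,
        download_handler="auto",  # resolved to AutoDownloadHandler() below
    )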
@@ -143,12 +143,31 @@ class Research(Chain):
             _text_splitter = text_splitter
         else:
             raise TypeError(f"Invalid text splitter: {type(text_splitter)}")
-        reader = ParallelApplyChain(
-            chain=DocReadingChain(
-                qa_chain,
-                max_num_pages_per_doc=max_num_pages_per_doc,
-                text_splitter=text_splitter,
-            ),
+        if isinstance(download_handler, str):
+            if download_handler == "auto":
+                _download_handler = AutoDownloadHandler()
+            else:
+                raise ValueError(f"Invalid download handler: {download_handler}")
+        elif isinstance(download_handler, DownloadHandler):
+            _download_handler = download_handler
+        else:
+            raise TypeError(f"Invalid download handler: {type(download_handler)}")
+
+        searcher = GenericSearcher.from_llms(
+            link_selection_llm,
+            query_generation_llm,
+            top_k_per_search=top_k_per_search,
+        )
+
+        doc_reading_chain = DocReadingChain(
+            chain=qa_chain,
+            max_num_pages_per_doc=max_num_pages_per_doc,
+            text_splitter=_text_splitter,
+        )
+
+        # Can read multiple documents in parallel
+        multi_reader = ParallelApplyChain(
+            chain=doc_reading_chain,
             max_concurrency=max_concurrency,
         )
-        return cls(searcher=searcher, reader=reader)
+        return cls(searcher=searcher, reader=multi_reader, downloader=_download_handler)
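
The new download_handler argument follows the same accept-an-instance-or-a-literal convention as text_splitter: the literal "auto" resolves to AutoDownloadHandler, an unknown string raises ValueError, and anything that is neither a string nor a DownloadHandler raises TypeError. (The rewrite also fixes a bug in the removed code, which passed the raw text_splitter union into DocReadingChain instead of the normalized _text_splitter.) A self-contained sketch of that normalization idiom, with illustrative stand-in names rather than the library's classes:

    from typing import Literal, Union

    class Handler:  # stand-in for DownloadHandler
        pass

    class AutoHandler(Handler):  # stand-in for AutoDownloadHandler
        pass

    def resolve_handler(handler: Union[Handler, Literal["auto"]]) -> Handler:
        """Normalize the user-facing union into a concrete instance."""
        if isinstance(handler, str):
            if handler == "auto":
                return AutoHandler()
            raise ValueError(f"Invalid handler: {handler}")  # right type, bad value
        if isinstance(handler, Handler):
            return handler  # already a concrete handler
        raise TypeError(f"Invalid handler: {type(handler)}")  # wrong type entirely

Raising ValueError for a bad string but TypeError for a foreign object mirrors the chain's own checks and keeps the two failure modes distinct. The remaining hunks live in the readers module that DocReadingChain is imported from.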


@@ -12,7 +12,10 @@ from langchain.text_splitter import TextSplitter


 class DocReadingChain(Chain):
-    """A chain that reads the document.
+    """A reader chain should use one of the QA chains to answer a question.
+
+    This chain is also responsible for splitting the document into smaller chunks
+    and then passing the chunks to an underlying QA chain.

     A brute force chain that reads an entire document (or the first N pages).
     """
@@ -43,11 +46,8 @@ class DocReadingChain(Chain):
     ) -> Dict[str, Any]:
         """Process a long document synchronously."""
         source_document = inputs["doc"]
-        if not isinstance(source_document, Document):
-            raise TypeError(f"Expected a Document, got {type(source_document)}")
-
         question = inputs["question"]
         sub_docs = self.text_splitter.split_documents([source_document])

         if self.max_num_docs > 0:
             _sub_docs = sub_docs[: self.max_num_docs]
@@ -70,9 +70,9 @@ class DocReadingChain(Chain):
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         """Process a long document asynchronously."""
-        doc = inputs["doc"]
+        source_document = inputs["doc"]
         question = inputs["question"]
-        sub_docs = self.text_splitter.split_documents([doc])
+        sub_docs = self.text_splitter.split_documents([source_document])
         if self.max_num_docs > 0:
             _sub_docs = sub_docs[: self.max_num_docs]
         else:
@@ -83,7 +83,7 @@ class DocReadingChain(Chain):
         )

         summary_doc = Document(
             page_content=results["output_text"],
-            metadata=doc.metadata,
+            metadata=source_document.metadata,
         )
         return {"answer": summary_doc}
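
With the rename, the async path now mirrors _call line for line. For completeness, a hedged sketch of driving the chain asynchronously, assuming the standard Chain.acall interface of this langchain version:

    import asyncio

    async def answer(chain: DocReadingChain, doc, question: str):
        result = await chain.acall({"doc": doc, "question": question})
        return result["answer"]  # the summary Document built in _acall

    # asyncio.run(answer(doc_reading_chain, some_document, "What changed?"))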