mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
x
This commit is contained in:
@@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.research.fetch import DownloadHandler
|
||||
from langchain.chains.research.fetch import DownloadHandler, AutoDownloadHandler
|
||||
from langchain.chains.research.readers import DocReadingChain, ParallelApplyChain
|
||||
from langchain.chains.research.search import GenericSearcher
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
@@ -106,34 +106,34 @@ class Research(Chain):
|
||||
@classmethod
|
||||
def from_llms(
|
||||
cls,
|
||||
link_selection_llm: BaseLanguageModel,
|
||||
query_generation_llm: BaseLanguageModel,
|
||||
qa_chain: LLMChain,
|
||||
*,
|
||||
query_generation_llm: BaseLanguageModel,
|
||||
link_selection_llm: BaseLanguageModel,
|
||||
qa_chain: LLMChain,
|
||||
top_k_per_search: int = -1,
|
||||
max_concurrency: int = 1,
|
||||
max_num_pages_per_doc: int = 100,
|
||||
text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
|
||||
download_handler: Union[DownloadHandler, Literal["auto"]] = "auto",
|
||||
) -> Research:
|
||||
"""Helper to create a research chain from standard llm related components.
|
||||
|
||||
Args:
|
||||
link_selection_llm: The language model to use for link selection.
|
||||
query_generation_llm: The language model to use for query generation.
|
||||
link_selection_llm: The language model to use for link selection.
|
||||
qa_chain: The chain to use to answer the question.
|
||||
top_k_per_search: The number of documents to return per search.
|
||||
max_concurrency: The maximum number of concurrent reads.
|
||||
max_num_pages_per_doc: The maximum number of pages to read per document.
|
||||
text_splitter: The text splitter to use to split the document into smaller chunks.
|
||||
download_handler: The download handler to use to download the documents.
|
||||
Provide either a download handler or the name of a
|
||||
download handler.
|
||||
- "auto" swaps between using requests and playwright
|
||||
|
||||
Returns:
|
||||
A research chain.
|
||||
"""
|
||||
searcher = GenericSearcher.from_llms(
|
||||
link_selection_llm,
|
||||
query_generation_llm,
|
||||
top_k_per_search=top_k_per_search,
|
||||
)
|
||||
if isinstance(text_splitter, str):
|
||||
if text_splitter == "recursive":
|
||||
_text_splitter = RecursiveCharacterTextSplitter()
|
||||
@@ -143,12 +143,31 @@ class Research(Chain):
|
||||
_text_splitter = text_splitter
|
||||
else:
|
||||
raise TypeError(f"Invalid text splitter: {type(text_splitter)}")
|
||||
reader = ParallelApplyChain(
|
||||
chain=DocReadingChain(
|
||||
qa_chain,
|
||||
max_num_pages_per_doc=max_num_pages_per_doc,
|
||||
text_splitter=text_splitter,
|
||||
),
|
||||
|
||||
if isinstance(download_handler, str):
|
||||
if download_handler == "auto":
|
||||
_download_handler = AutoDownloadHandler()
|
||||
else:
|
||||
raise ValueError(f"Invalid download handler: {download_handler}")
|
||||
elif isinstance(download_handler, DownloadHandler):
|
||||
_download_handler = download_handler
|
||||
else:
|
||||
raise TypeError(f"Invalid download handler: {type(download_handler)}")
|
||||
|
||||
searcher = GenericSearcher.from_llms(
|
||||
link_selection_llm,
|
||||
query_generation_llm,
|
||||
top_k_per_search=top_k_per_search,
|
||||
)
|
||||
|
||||
doc_reading_chain = DocReadingChain(
|
||||
chain=qa_chain,
|
||||
max_num_pages_per_doc=max_num_pages_per_doc,
|
||||
text_splitter=_text_splitter,
|
||||
)
|
||||
# Can read multiple documents in parallel
|
||||
multi_reader = ParallelApplyChain(
|
||||
chain=doc_reading_chain,
|
||||
max_concurrency=max_concurrency,
|
||||
)
|
||||
return cls(searcher=searcher, reader=reader)
|
||||
return cls(searcher=searcher, reader=multi_reader, downloader=_download_handler)
|
||||
|
||||
@@ -12,7 +12,10 @@ from langchain.text_splitter import TextSplitter
|
||||
|
||||
|
||||
class DocReadingChain(Chain):
|
||||
"""A chain that reads the document.
|
||||
"""A reader chain should use one of the QA chains to answer a question.
|
||||
|
||||
This chain is also responsible for splitting the document into smaller chunks
|
||||
and then passing the chunks to an underlying QA chain.
|
||||
|
||||
A brute force chain that reads an entire document (or the first N pages).
|
||||
"""
|
||||
@@ -43,11 +46,8 @@ class DocReadingChain(Chain):
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a long document synchronously."""
|
||||
source_document = inputs["doc"]
|
||||
|
||||
if not isinstance(source_document, Document):
|
||||
raise TypeError(f"Expected a Document, got {type(source_document)}")
|
||||
|
||||
question = inputs["question"]
|
||||
|
||||
sub_docs = self.text_splitter.split_documents([source_document])
|
||||
if self.max_num_docs > 0:
|
||||
_sub_docs = sub_docs[: self.max_num_docs]
|
||||
@@ -70,9 +70,9 @@ class DocReadingChain(Chain):
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a long document asynchronously."""
|
||||
doc = inputs["doc"]
|
||||
source_document = inputs["doc"]
|
||||
question = inputs["question"]
|
||||
sub_docs = self.text_splitter.split_documents([doc])
|
||||
sub_docs = self.text_splitter.split_documents([source_document])
|
||||
if self.max_num_docs > 0:
|
||||
_sub_docs = sub_docs[: self.max_num_docs]
|
||||
else:
|
||||
@@ -83,7 +83,7 @@ class DocReadingChain(Chain):
|
||||
)
|
||||
summary_doc = Document(
|
||||
page_content=results["output_text"],
|
||||
metadata=doc.metadata,
|
||||
metadata=source_document.metadata,
|
||||
)
|
||||
|
||||
return {"answer": summary_doc}
|
||||
|
||||
Reference in New Issue
Block a user