mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
q
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Any, Union, Literal, List, Dict, Mapping
|
||||
import itertools
|
||||
|
||||
@@ -35,8 +36,12 @@ class Research(Chain):
|
||||
to answer certain kinds of questions correctly.
|
||||
4. A reader that reads the documents and produces an answer.
|
||||
|
||||
This research chain only implements a single hop at the moment; i.e.,
|
||||
it goes from the questions to a list of URLs to documents to compiling answers.
|
||||
Limitations:
|
||||
* This research chain only implements a single hop at the moment; i.e.,
|
||||
it goes from the questions to a list of URLs to documents to compiling answers.
|
||||
* The reader chain needs to match the task. For example, if using a QA refine
|
||||
chain, a task of collecting a list of entries from a long document will
|
||||
fail because the QA refine chain is not designed to handle such a task.
|
||||
|
||||
The chain can be extended to continue crawling the documents in an attempt
|
||||
to discover relevant pages that were not surfaced by the search engine.
|
||||
@@ -66,7 +71,7 @@ class Research(Chain):
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the output keys."""
|
||||
return ["docs", "summary"]
|
||||
return ["docs"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@@ -75,14 +80,21 @@ class Research(Chain):
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain synchronously."""
|
||||
question = inputs["question"]
|
||||
search_results = self.searcher({"question": question})
|
||||
search_results = self.searcher(
|
||||
{"question": question},
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
urls = search_results["urls"]
|
||||
blobs = self.downloader.download(urls)
|
||||
raise NotImplementedError()
|
||||
# docs = []
|
||||
# docs = self.reader({"blobs": blobs})
|
||||
# summary = self.summarizer({"docs": docs})
|
||||
# return {"docs": docs, "summary": summary}
|
||||
parser = MarkdownifyHTMLParser()
|
||||
docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
|
||||
inputs = [{"doc": doc, "question": question} for doc in docs]
|
||||
results = self.reader(
|
||||
inputs, callbacks=run_manager.get_child() if run_manager else None
|
||||
)
|
||||
return {
|
||||
"docs": [result["answer"] for result in results["inputs"]],
|
||||
}
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
@@ -91,16 +103,21 @@ class Research(Chain):
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the chain asynchronously."""
|
||||
question = inputs["question"]
|
||||
search_results = await self.searcher.acall({"question": question})
|
||||
search_results = await self.searcher.acall(
|
||||
{"question": question},
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
urls = search_results["urls"]
|
||||
blobs = await self.downloader.adownload(urls)
|
||||
parser = MarkdownifyHTMLParser()
|
||||
docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
|
||||
inputs = [{"doc": doc, "question": question} for doc in docs]
|
||||
results = await self.reader.acall(inputs)
|
||||
results = await self.reader.acall(
|
||||
inputs,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return {
|
||||
"docs": [result["answer"] for result in results],
|
||||
"summary": None,
|
||||
"docs": [result["answer"] for result in results["results"]],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -109,7 +126,7 @@ class Research(Chain):
|
||||
*,
|
||||
query_generation_llm: BaseLanguageModel,
|
||||
link_selection_llm: BaseLanguageModel,
|
||||
qa_chain: LLMChain,
|
||||
underlying_reader_chain: LLMChain,
|
||||
top_k_per_search: int = -1,
|
||||
max_concurrency: int = 1,
|
||||
max_num_pages_per_doc: int = 5,
|
||||
@@ -122,7 +139,7 @@ class Research(Chain):
|
||||
Args:
|
||||
query_generation_llm: The language model to use for query generation.
|
||||
link_selection_llm: The language model to use for link selection.
|
||||
qa_chain: The chain to use to answer the question.
|
||||
underlying_reader_chain: The chain to use to answer the question.
|
||||
top_k_per_search: The number of documents to return per search.
|
||||
max_concurrency: The maximum number of concurrent reads.
|
||||
max_num_pages_per_doc: The maximum number of pages to read per document.
|
||||
@@ -165,8 +182,8 @@ class Research(Chain):
|
||||
)
|
||||
|
||||
doc_reading_chain = DocReadingChain(
|
||||
chain=qa_chain,
|
||||
max_num_pages_per_doc=max_num_pages_per_doc,
|
||||
chain=underlying_reader_chain,
|
||||
max_num_docs=max_num_pages_per_doc,
|
||||
text_splitter=_text_splitter,
|
||||
)
|
||||
# Can read multiple documents in parallel
|
||||
|
||||
@@ -26,8 +26,11 @@ class DocReadingChain(Chain):
|
||||
text_splitter: TextSplitter
|
||||
"""The text splitter to use to split the document into smaller chunks."""
|
||||
|
||||
max_num_docs: int = -1
|
||||
"""The maximum number of documents to split the document into."""
|
||||
max_num_docs: int
|
||||
"""The maximum number of documents to split the document into.
|
||||
|
||||
Use -1 to denote no limit to the number of pages to read.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@@ -77,7 +80,7 @@ class DocReadingChain(Chain):
|
||||
_sub_docs = sub_docs[: self.max_num_docs]
|
||||
else:
|
||||
_sub_docs = sub_docs
|
||||
|
||||
|
||||
results = await self.chain.acall(
|
||||
{"input_documents": _sub_docs, "question": question},
|
||||
callbacks=run_manager.get_child(),
|
||||
@@ -94,6 +97,8 @@ class ParallelApplyChain(Chain):
|
||||
"""Utility chain to apply a given chain in parallel across input documents.
|
||||
|
||||
This chain needs to handle a limit on concurrency.
|
||||
|
||||
WARNING: Parallelization only implemented on the async path.
|
||||
"""
|
||||
|
||||
chain: Chain
|
||||
|
||||
Reference in New Issue
Block a user