This commit is contained in:
Eugene Yurtsev
2023-06-01 17:30:58 -04:00
parent b7dabe8f50
commit 7e14388ba8

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Any, Union, Literal
from typing import Optional, Any, Union, Literal, List
from plistlib import Dict
@@ -11,6 +11,7 @@ from langchain.chains.research.readers import DocReadingChain, ParallelApplyChai
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.research.search import GenericSearcher
from langchain.text_splitter import TextSplitter
from langchain.chains.research.fetch import PlaywrightDownloadHandler, DownloadHandler
class Research(Chain):
@@ -19,6 +20,17 @@ class Research(Chain):
"""The searcher to use to search for documents."""
reader: Chain
"""The reader to use to read documents and produce an answer."""
downloader: DownloadHandler
@property
def input_keys(self) -> List[str]:
"""Return the input keys."""
return ["question"]
@property
def output_keys(self) -> List[str]:
"""Return the output keys."""
return ["docs", "summary"]
def _call(
self,
@@ -32,15 +44,11 @@ class Research(Chain):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the chain asynchronously."""
searcher = self.searcher
reader = self.reader
question = inputs["question"]
search_results = searcher({"question": question})
documents = [result["document"] for result in search_results["results"]]
if not documents:
return {"answer": None}
reader_inputs = {"doc": documents, "question": question}
return await reader(reader_inputs, callbacks=run_manager.get_child())
search_results = self.searcher({"question": question})
urls = search_results["urls"]
blobs = self.downloader.download(urls)
raise NotImplementedError()
@classmethod
def from_llms(