From 7e14388ba89ee6cf53bc9ec4723f4dabe0318450 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 1 Jun 2023 17:30:58 -0400 Subject: [PATCH] q --- langchain/chains/research/api.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/langchain/chains/research/api.py b/langchain/chains/research/api.py index ed2d62a31a4..a4ccaf287ae 100644 --- a/langchain/chains/research/api.py +++ b/langchain/chains/research/api.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, Any, Union, Literal +from typing import Optional, Any, Union, Literal, List from plistlib import Dict @@ -11,6 +11,7 @@ from langchain.chains.research.readers import DocReadingChain, ParallelApplyChai from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.chains.research.search import GenericSearcher from langchain.text_splitter import TextSplitter +from langchain.chains.research.fetch import PlaywrightDownloadHandler, DownloadHandler class Research(Chain): @@ -19,6 +20,17 @@ class Research(Chain): """The searcher to use to search for documents.""" reader: Chain """The reader to use to read documents and produce an answer.""" + downloader: DownloadHandler + + @property + def input_keys(self) -> List[str]: + """Return the input keys.""" + return ["question"] + + @property + def output_keys(self) -> List[str]: + """Return the output keys.""" + return ["docs", "summary"] def _call( self, @@ -32,15 +44,11 @@ class Research(Chain): run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: """Run the chain asynchronously.""" - searcher = self.searcher - reader = self.reader question = inputs["question"] - search_results = searcher({"question": question}) - documents = [result["document"] for result in search_results["results"]] - if not documents: - return {"answer": None} - reader_inputs = {"doc": documents, "question": question} - return await reader(reader_inputs, callbacks=run_manager.get_child()) + search_results = self.searcher({"question": question}) + urls = search_results["urls"] + blobs = self.downloader.download(urls) + raise NotImplementedError() @classmethod def from_llms(