This commit is contained in:
Eugene Yurtsev
2023-06-02 14:37:57 -04:00
parent 37bdeb60fc
commit 1911388d9d
5 changed files with 31 additions and 19 deletions

View File

@@ -132,8 +132,17 @@ def _write_records_to_string(
T = TypeVar("T")
def batch(iterable: Iterable[T], size) -> Iterator[List[T]]:
"""Batch an iterable into chunks of size `size`."""
def _batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
"""Batch an iterable into chunks of size `size`.
Args:
iterable: the iterable to batch
size: the size of each batch
Returns:
iterator over batches of size `size`; the final batch may contain fewer
than `size` elements
"""
iterator = iter(iterable)
while True:
batch = list(islice(iterator, size))
@@ -167,12 +176,12 @@ class MultiSelectChain(Chain):
question = inputs["question"]
columns = inputs.get("columns", None)
selected = []
selected: List[Mapping[str, Any]] = []
# TODO(): Balance choices into equal batches with constraint dependent
# on context window and prompt
max_choices = 30
for choice_batch in batch(choices, max_choices):
for choice_batch in _batch(choices, max_choices):
records_with_ids = [
{**record, "id": idx} for idx, record in enumerate(choice_batch)
]
@@ -185,7 +194,7 @@ class MultiSelectChain(Chain):
self.llm_chain.predict_and_parse(
records=records_str,
question=question,
callbacks=run_manager.get_child(),
callbacks=run_manager.get_child() if run_manager else None,
),
)
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
@@ -204,12 +213,12 @@ class MultiSelectChain(Chain):
question = inputs["question"]
columns = inputs.get("columns", None)
selected = []
selected: List[Mapping[str, Any]] = []
# TODO(): Balance choices into equal batches with constraint dependent
# on context window and prompt
max_choices = 30
for choice_batch in batch(choices, max_choices):
for choice_batch in _batch(choices, max_choices):
records_with_ids = [
{**record, "id": idx} for idx, record in enumerate(choice_batch)
]
@@ -222,7 +231,7 @@ class MultiSelectChain(Chain):
await self.llm_chain.apredict_and_parse(
records=records_str,
question=question,
callbacks=run_manager.get_child(),
callbacks=run_manager.get_child() if run_manager else None,
),
)
valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]

View File

@@ -36,7 +36,8 @@ class Research(Chain):
Limitations:
* This research chain only implements a single hop at the moment; i.e.,
it goes from the questions to a list of URLs to documents to compiling answers.
it goes from the questions to a list of URLs to documents to compiling
answers.
* The reader chain needs to match the task. For example, if using a QA refine
chain, a task of collecting a list of entries from a long document will
fail because the QA refine chain is not designed to handle such a task.
@@ -145,7 +146,8 @@ class Research(Chain):
top_k_per_search: The number of documents to return per search.
max_concurrency: The maximum number of concurrent reads.
max_num_pages_per_doc: The maximum number of pages to read per document.
text_splitter: The text splitter to use to split the document into smaller chunks.
text_splitter: The text splitter to use to split the document into
smaller chunks.
download_handler: The download handler to use to download the documents.
Provide either a download handler or the name of a
download handler.

View File

@@ -70,7 +70,7 @@ class PlaywrightDownloadHandler(DownloadHandler):
"""
self.timeout = timeout
def download(self, urls: Sequence[str]) -> List[Blob]:
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download list of urls synchronously."""
return asyncio.run(self.adownload(urls))
@@ -87,7 +87,7 @@ class PlaywrightDownloadHandler(DownloadHandler):
html_content = None
return html_content
async def adownload(self, urls: Sequence[str]) -> List[Blob]:
async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLs asynchronously using playwright.
Args:
@@ -112,7 +112,7 @@ class RequestsDownloadHandler(DownloadHandler):
"""Initialize the requests download handler."""
self.web_downloader = web_downloader or WebBaseLoader(web_path=[])
def download(self, urls: Sequence[str]) -> str:
def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
"""Download a batch of URLS synchronously."""
return asyncio.run(self.adownload(urls))
@@ -160,7 +160,7 @@ class AutoDownloadHandler(DownloadHandler):
must_redownload = [
(idx, url)
for idx, (url, blob) in enumerate(zip(urls, blobs))
if _is_javascript_required(blob.as_string())
if blob is not None and _is_javascript_required(blob.as_string())
]
if must_redownload:
indexes, urls_to_redownload = zip(*must_redownload)

View File

@@ -59,7 +59,7 @@ class DocReadingChain(Chain):
response = self.chain(
{"input_documents": _sub_docs, "question": question},
callbacks=run_manager.get_child(),
callbacks=run_manager.get_child() if run_manager else None,
)
summary_doc = Document(
page_content=response["output_text"],
@@ -83,7 +83,7 @@ class DocReadingChain(Chain):
results = await self.chain.acall(
{"input_documents": _sub_docs, "question": question},
callbacks=run_manager.get_child(),
callbacks=run_manager.get_child() if run_manager else None,
)
summary_doc = Document(
page_content=results["output_text"],

View File

@@ -167,7 +167,7 @@ async def _arun_searches(
a list of unique search results
"""
wrapper = serpapi.SerpAPIWrapper()
tasks = [wrapper.results(query) for query in queries]
tasks = [wrapper.aresults(query) for query in queries]
results = await asyncio.gather(*tasks)
finalized_results = []
@@ -218,7 +218,8 @@ class GenericSearcher(Chain):
1. Breaking a complex question into a series of simpler queries using an LLM.
2. Running the queries against a search engine.
3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf or other models).
3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf
or other models).
This chain is not meant to be used for questions requiring multiple hops to answer.
@@ -236,7 +237,7 @@ class GenericSearcher(Chain):
"""
query_generator: LLMChain
"""An LLM that is used to break down a complex question into a list of simpler queries."""
"""An LLM used to break down a complex question into a list of simpler queries."""
link_selection_model: Chain
"""An LLM that is used to select the most relevant urls from the search results."""
top_k_per_search: int = -1