commit 39a2d2511d (parent 679eb9f14f)
Author: Eugene Yurtsev
Date:   2023-06-02 11:23:38 -04:00

6 changed files with 104 additions and 17 deletions

View File

@@ -21,7 +21,10 @@ from typing import (
 from langchain import LLMChain, PromptTemplate
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.callbacks.manager import (
+    CallbackManagerForChainRun,
+    AsyncCallbackManagerForChainRun,
+)
 from langchain.chains.base import Chain
 from langchain.schema import BaseOutputParser
@@ -191,6 +194,43 @@ class MultiSelectChain(Chain):
             "selected": selected,
         }
 
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        choices = inputs["choices"]
+        question = inputs["question"]
+        columns = inputs.get("columns", None)
+        selected = []
+        # TODO(): Balance choices into equal batches with constraint dependent
+        # on context window and prompt
+        max_choices = 30
+        for choice_batch in batch(choices, max_choices):
+            records_with_ids = [
+                {**record, "id": idx} for idx, record in enumerate(choice_batch)
+            ]
+            records_str = _write_records_to_string(
+                records_with_ids, columns=columns, delimiter="|"
+            )
+            indexes = cast(
+                List[int],
+                await self.llm_chain.apredict_and_parse(
+                    records=records_str,
+                    question=question,
+                    callbacks=run_manager.get_child(),
+                ),
+            )
+            valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
+            selected.extend(choice_batch[i] for i in valid_indexes)
+        return {
+            "selected": selected,
+        }
+
     @property
     def _chain_type(self) -> str:
         """Return the chain type."""

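The new `_acall` relies on a `batch` helper that is not shown in this diff. A minimal sketch of what such a helper presumably does, with the name and signature assumed from its call site:

    from typing import Iterable, Iterator, List, TypeVar

    T = TypeVar("T")

    def batch(items: Iterable[T], size: int) -> Iterator[List[T]]:
        """Yield successive chunks of at most `size` items (assumed helper)."""
        chunk: List[T] = []
        for item in items:
            chunk.append(item)
            if len(chunk) == size:
                yield chunk
                chunk = []
        if chunk:
            yield chunk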
View File

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Optional, Any, Union, Literal, List, Dict
+from typing import Optional, Any, Union, Literal, List, Dict, Mapping
 
 import itertools
@@ -93,7 +93,7 @@ class Research(Chain):
         question = inputs["question"]
         search_results = await self.searcher.acall({"question": question})
         urls = search_results["urls"]
-        blobs = self.downloader.download(urls)
+        blobs = await self.downloader.adownload(urls)
         parser = MarkdownifyHTMLParser()
         docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
         inputs = [{"doc": doc, "question": question} for doc in docs]
@@ -112,9 +112,10 @@ class Research(Chain):
         qa_chain: LLMChain,
         top_k_per_search: int = -1,
         max_concurrency: int = 1,
-        max_num_pages_per_doc: int = 100,
+        max_num_pages_per_doc: int = 5,
         text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
         download_handler: Union[DownloadHandler, Literal["auto"]] = "auto",
+        text_splitter_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> Research:
         """Helper to create a research chain from standard llm related components.
@@ -130,13 +131,16 @@ class Research(Chain):
                 Provide either a download handler or the name of a
                 download handler.
                 - "auto" swaps between using requests and playwright
+            text_splitter_kwargs: Keyword arguments to pass to the text splitter.
+                Only used when the text splitter is given by name.
 
         Returns:
             A research chain.
         """
         if isinstance(text_splitter, str):
             if text_splitter == "recursive":
-                _text_splitter = RecursiveCharacterTextSplitter()
+                _text_splitter_kwargs = text_splitter_kwargs or {}
+                _text_splitter = RecursiveCharacterTextSplitter(**_text_splitter_kwargs)
             else:
                 raise ValueError(f"Invalid text splitter: {text_splitter}")
         elif isinstance(text_splitter, TextSplitter):
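With this change, splitter options can be forwarded when the splitter is given by name rather than as an instance. A minimal usage sketch, assuming the standard `chunk_size`/`chunk_overlap` arguments of `RecursiveCharacterTextSplitter`:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # What the factory now builds internally when given
    # text_splitter="recursive" plus these kwargs.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)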

View File

@@ -14,9 +14,10 @@ Downloading is batched by default to allow efficient parallelization.
 import abc
 import asyncio
 import mimetypes
-from bs4 import BeautifulSoup
 from typing import Sequence, List, Any, Optional
 
+from bs4 import BeautifulSoup
+
 from langchain.document_loaders import WebBaseLoader
 from langchain.document_loaders.blob_loaders import Blob
@@ -61,8 +62,7 @@ class PlaywrightDownloadHandler(DownloadHandler):
     def download(self, urls: Sequence[str]) -> List[Blob]:
         """Download list of urls synchronously."""
-        # Implement using a threadpool or using playwright API if it supports it
-        raise NotImplementedError()
+        return asyncio.run(self.adownload(urls))
 
     async def _download(self, browser: Any, url: str) -> str:
         """Download a url asynchronously using playwright."""
@@ -92,14 +92,13 @@ class PlaywrightDownloadHandler(DownloadHandler):
 class RequestsDownloadHandler(DownloadHandler):
-    def __init__(self, web_downloader: WebBaseLoader) -> None:
+    def __init__(self, web_downloader: Optional[WebBaseLoader] = None) -> None:
         """Initialize the requests download handler."""
-        self.web_downloader = web_downloader
+        self.web_downloader = web_downloader or WebBaseLoader(web_path=[])
 
     def download(self, urls: Sequence[str]) -> List[Blob]:
         """Download a batch of URLs synchronously."""
-        # Implement with threadpool.
-        raise NotImplementedError()
+        return asyncio.run(self.adownload(urls))
 
     async def adownload(self, urls: Sequence[str]) -> List[Blob]:
         """Download a batch of urls asynchronously using requests."""
@@ -138,11 +137,16 @@ class AutoDownloadHandler(DownloadHandler):
         must_redownload = [
             (idx, url)
             for idx, (url, blob) in enumerate(zip(urls, blobs))
-            if _is_javascript_required(blob.data)
+            if _is_javascript_required(blob.as_string())
         ]
-        indexes, urls_to_redownload = zip(*must_redownload)
-        new_blobs = await self.playwright_downloader.adownload(urls_to_redownload)
-        for idx, blob in zip(indexes, new_blobs):
-            blobs[idx] = blob
+        if must_redownload:
+            indexes, urls_to_redownload = zip(*must_redownload)
+            new_blobs = await self.playwright_downloader.adownload(urls_to_redownload)
+            for idx, blob in zip(indexes, new_blobs):
+                blobs[idx] = blob
         return blobs
 
+    def download(self, urls: Sequence[str]) -> List[Blob]:
+        """Download a batch of URLs synchronously."""
+        return asyncio.run(self.adownload(urls))
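The new `if must_redownload:` guard fixes a crash on the empty case: `zip(*[])` is just `zip()`, which yields nothing, so the tuple unpacking raised a ValueError whenever no page needed a JavaScript re-download. A standalone illustration of the failure mode and the fix:

    must_redownload = []  # no page required a JavaScript re-download

    # Before the fix: unpacking an empty zip raises
    # "ValueError: not enough values to unpack (expected 2, got 0)".
    try:
        indexes, urls_to_redownload = zip(*must_redownload)
    except ValueError as exc:
        print(exc)

    # After the fix: the guard skips the re-download step entirely.
    if must_redownload:
        indexes, urls_to_redownload = zip(*must_redownload)

Note also that the new synchronous `download` wrappers delegate via `asyncio.run`, which raises a RuntimeError when called from inside an already-running event loop; callers already in async code should use `adownload` directly.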

View File

@@ -77,6 +77,7 @@ class DocReadingChain(Chain):
             _sub_docs = sub_docs[: self.max_num_docs]
         else:
             _sub_docs = sub_docs
 
         results = await self.chain.acall(
             {"input_documents": _sub_docs, "question": question},
+            callbacks=run_manager.get_child(),

View File

@@ -0,0 +1,38 @@
+"""Tests for the downloader."""
+from langchain.chains.research.fetch import (
+    AutoDownloadHandler,
+    _is_javascript_required,
+    RequestsDownloadHandler,
+    PlaywrightDownloadHandler,
+)
+
+
+def test_is_javascript_required():
+    """Check whether a given page should be re-downloaded with javascript executed."""
+    assert not _is_javascript_required(
+        """
+        <html>
+        <body>
+        <p>Check whether javascript is required.</p>
+        </body>
+        </html>
+        """
+    )
+    assert _is_javascript_required(
+        """
+        <html>
+        <body>
+        <script>
+        console.log("Javascript is required.");
+        </script>
+        </body>
+        </html>
+        """
+    )
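The `_is_javascript_required` implementation itself is not part of this diff. A minimal sketch of a heuristic consistent with these tests, using the BeautifulSoup import moved above (the function name here is illustrative, not the library's):

    from bs4 import BeautifulSoup

    def is_javascript_required_sketch(html: str) -> bool:
        """Assumed heuristic: a page with script tags but no visible
        text likely needs JavaScript executed to render content."""
        soup = BeautifulSoup(html, "html.parser")
        has_scripts = soup.find("script") is not None
        for tag in soup.find_all("script"):
            tag.decompose()  # drop script contents from the text view
        return has_scripts and not soup.get_text(strip=True)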
+
+
+def test_requests_handler():
+    """Test that the requests handler is working."""
+    handler = RequestsDownloadHandler()
+    fetch = handler.download(["https://www.google.com"])