mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
x
This commit is contained in:
@@ -21,7 +21,10 @@ from typing import (
|
||||
|
||||
from langchain import LLMChain, PromptTemplate
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForChainRun,
|
||||
AsyncCallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseOutputParser
|
||||
|
||||
@@ -191,6 +194,43 @@ class MultiSelectChain(Chain):
|
||||
"selected": selected,
|
||||
}
|
||||
|
||||
async def _acall(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Asynchronously select the choices that are relevant to a question.

    The choices are processed in batches. Each batch is rendered as a
    "|"-delimited string in which every record carries an ``id`` equal to
    its position inside the batch, and the LLM chain is asked for the ids
    to keep.

    Args:
        inputs: Must contain ``"choices"`` (mapping-like records) and
            ``"question"``. May contain ``"columns"``, which is forwarded
            to the record writer.
        run_manager: Optional callback manager for this run. When it is
            omitted the LLM chain is invoked without child callbacks.

    Returns:
        A dict whose ``"selected"`` entry lists the chosen records, batch
        by batch, in the order the LLM chain returned their ids.
    """
    choices = inputs["choices"]
    question = inputs["question"]
    columns = inputs.get("columns")

    # TODO(): Balance choices into equal batches with constraint dependent
    # on context window and prompt
    max_choices = 30

    selected = []
    for choice_batch in batch(choices, max_choices):
        # The id is the position within the batch (overriding any "id" the
        # record already had), so parsed ids index straight into the batch.
        records_with_ids = [
            {**record, "id": idx} for idx, record in enumerate(choice_batch)
        ]
        records_str = _write_records_to_string(
            records_with_ids, columns=columns, delimiter="|"
        )

        # ``run_manager`` defaults to None; only derive child callbacks when
        # a manager was actually supplied.
        callbacks = run_manager.get_child() if run_manager is not None else None
        indexes = cast(
            List[int],
            await self.llm_chain.apredict_and_parse(
                records=records_str,
                question=question,
                callbacks=callbacks,
            ),
        )

        # Ignore ids that do not correspond to a record in this batch.
        batch_size = len(choice_batch)
        selected.extend(
            choice_batch[idx] for idx in indexes if 0 <= idx < batch_size
        )

    return {"selected": selected}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
"""Return the chain type."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Any, Union, Literal, List, Dict
|
||||
from typing import Optional, Any, Union, Literal, List, Dict, Mapping
|
||||
import itertools
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class Research(Chain):
|
||||
question = inputs["question"]
|
||||
search_results = await self.searcher.acall({"question": question})
|
||||
urls = search_results["urls"]
|
||||
blobs = self.downloader.download(urls)
|
||||
blobs = await self.downloader.adownload(urls)
|
||||
parser = MarkdownifyHTMLParser()
|
||||
docs = itertools.chain.from_iterable(parser.lazy_parse(blob) for blob in blobs)
|
||||
inputs = [{"doc": doc, "question": question} for doc in docs]
|
||||
@@ -112,9 +112,10 @@ class Research(Chain):
|
||||
qa_chain: LLMChain,
|
||||
top_k_per_search: int = -1,
|
||||
max_concurrency: int = 1,
|
||||
max_num_pages_per_doc: int = 100,
|
||||
max_num_pages_per_doc: int = 5,
|
||||
text_splitter: Union[TextSplitter, Literal["recursive"]] = "recursive",
|
||||
download_handler: Union[DownloadHandler, Literal["auto"]] = "auto",
|
||||
text_splitter_kwargs: Optional[Mapping[str, Any]] = None,
|
||||
) -> Research:
|
||||
"""Helper to create a research chain from standard llm related components.
|
||||
|
||||
@@ -130,13 +131,16 @@ class Research(Chain):
|
||||
Provide either a download handler or the name of a
|
||||
download handler.
|
||||
- "auto" swaps between using requests and playwright
|
||||
text_splitter_kwargs: The keyword arguments to pass to the text splitter.
|
||||
Only use when providing a text splitter as string.
|
||||
|
||||
Returns:
|
||||
A research chain.
|
||||
"""
|
||||
if isinstance(text_splitter, str):
|
||||
if text_splitter == "recursive":
|
||||
_text_splitter = RecursiveCharacterTextSplitter()
|
||||
_text_splitter_kwargs = text_splitter_kwargs or {}
|
||||
_text_splitter = RecursiveCharacterTextSplitter(**_text_splitter_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid text splitter: {text_splitter}")
|
||||
elif isinstance(text_splitter, TextSplitter):
|
||||
|
||||
@@ -14,9 +14,10 @@ Downloading is batched by default to allow efficient parallelization.
|
||||
import abc
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Sequence, List, Any, Optional
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from langchain.document_loaders import WebBaseLoader
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
|
||||
@@ -61,8 +62,7 @@ class PlaywrightDownloadHandler(DownloadHandler):
|
||||
|
||||
def download(self, urls: Sequence[str]) -> List[Blob]:
    """Download list of urls synchronously.

    Drives ``adownload`` to completion with ``asyncio.run``. ``asyncio.run``
    cannot be used while another event loop is running in the same thread,
    so asynchronous callers should await ``adownload`` directly instead.

    Args:
        urls: The urls to download.

    Returns:
        The result of ``adownload`` for ``urls``.
    """
    return asyncio.run(self.adownload(urls))
|
||||
|
||||
async def _download(self, browser: Any, url: str) -> str:
|
||||
"""Download a url asynchronously using playwright."""
|
||||
@@ -92,14 +92,13 @@ class PlaywrightDownloadHandler(DownloadHandler):
|
||||
|
||||
|
||||
class RequestsDownloadHandler(DownloadHandler):
|
||||
def __init__(self, web_downloader: WebBaseLoader) -> None:
|
||||
def __init__(self, web_downloader: Optional[WebBaseLoader] = None) -> None:
|
||||
"""Initialize the requests download handler."""
|
||||
self.web_downloader = web_downloader
|
||||
self.web_downloader = web_downloader or WebBaseLoader(web_path=[])
|
||||
|
||||
def download(self, urls: Sequence[str]) -> List[Blob]:
    """Download a batch of URLS synchronously.

    Drives ``adownload`` to completion with ``asyncio.run``. ``asyncio.run``
    cannot be used while another event loop is running in the same thread,
    so asynchronous callers should await ``adownload`` directly instead.

    Args:
        urls: The urls to download.

    Returns:
        The result of ``adownload`` for ``urls``.
    """
    return asyncio.run(self.adownload(urls))
|
||||
|
||||
async def adownload(self, urls: Sequence[str]) -> List[Blob]:
|
||||
"""Download a batch of urls asynchronously using playwright."""
|
||||
@@ -138,11 +137,16 @@ class AutoDownloadHandler(DownloadHandler):
|
||||
must_redownload = [
|
||||
(idx, url)
|
||||
for idx, (url, blob) in enumerate(zip(urls, blobs))
|
||||
if _is_javascript_required(blob.data)
|
||||
if _is_javascript_required(blob.as_string())
|
||||
]
|
||||
indexes, urls_to_redownload = zip(*must_redownload)
|
||||
new_blobs = await self.playwright_downloader.adownload(urls_to_redownload)
|
||||
if must_redownload:
|
||||
indexes, urls_to_redownload = zip(*must_redownload)
|
||||
new_blobs = await self.playwright_downloader.adownload(urls_to_redownload)
|
||||
|
||||
for idx, blob in zip(indexes, new_blobs):
|
||||
blobs[idx] = blob
|
||||
for idx, blob in zip(indexes, new_blobs):
|
||||
blobs[idx] = blob
|
||||
return blobs
|
||||
|
||||
def download(self, urls: Sequence[str]) -> List[Blob]:
    """Blocking counterpart of ``adownload`` for a batch of URLs.

    The coroutine returned by ``adownload`` is run to completion with
    ``asyncio.run`` and its result is returned unchanged.
    """
    pending = self.adownload(urls)
    return asyncio.run(pending)
|
||||
|
||||
@@ -77,6 +77,7 @@ class DocReadingChain(Chain):
|
||||
_sub_docs = sub_docs[: self.max_num_docs]
|
||||
else:
|
||||
_sub_docs = sub_docs
|
||||
|
||||
results = await self.chain.acall(
|
||||
{"input_documents": _sub_docs, "question": question},
|
||||
callbacks=run_manager.get_child(),
|
||||
|
||||
0
tests/unit_tests/chains/research/__init__.py
Normal file
0
tests/unit_tests/chains/research/__init__.py
Normal file
38
tests/unit_tests/chains/research/test_downloader.py
Normal file
38
tests/unit_tests/chains/research/test_downloader.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Tests for the downloader."""
|
||||
from langchain.chains.research.fetch import (
|
||||
AutoDownloadHandler,
|
||||
_is_javascript_required,
|
||||
RequestsDownloadHandler,
|
||||
PlaywrightDownloadHandler,
|
||||
)
|
||||
|
||||
|
||||
def test_is_javascript_required():
    """A script-free page is not flagged for re-download; a scripted one is."""
    # Page whose body contains only a paragraph.
    static_page = """
<html>
<body>
<p>Check whether javascript is required.</p>
</body>
</html>
"""
    # Page whose body contains a script element.
    scripted_page = """
<html>
<body>
<script>
console.log("Javascript is required.");
</script>
</body>
</html>
"""

    assert not _is_javascript_required(static_page)
    assert _is_javascript_required(scripted_page)
|
||||
|
||||
|
||||
def test_requests_handler():
    """Test that the requests handler returns one result per requested url."""
    urls = ["https://www.google.com"]
    handler = RequestsDownloadHandler()
    # NOTE(review): this issues a real HTTP request — confirm that network
    # access is acceptable here, or inject a stubbed web downloader.
    blobs = handler.download(urls)
    # Previously the result was discarded, so the test could only fail by
    # raising; check that every url produced a result.
    assert len(blobs) == len(urls)
|
||||
Reference in New Issue
Block a user