Mirror of https://github.com/hwchase17/langchain.git, synced 2026-01-24 05:50:18 +00:00
@@ -132,8 +132,17 @@ def _write_records_to_string(
 T = TypeVar("T")
 
 
-def batch(iterable: Iterable[T], size) -> Iterator[List[T]]:
-    """Batch an iterable into chunks of size `size`."""
+def _batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
+    """Batch an iterable into chunks of size `size`.
+
+    Args:
+        iterable: the iterable to batch
+        size: the size of each batch
+
+    Returns:
+        an iterator over batches of size `size`, except for the last batch,
+        which may contain fewer than `size` items
+    """
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
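For context, a sketch of how the renamed `_batch` generator plausibly continues past the lines shown in this hunk; the termination logic is an assumption, since the diff cuts off mid-function:

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def _batch(iterable: Iterable[T], size: int) -> Iterator[List[T]]:
    """Batch an iterable into chunks of size `size`."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
        if not batch:  # islice returns [] once exhausted; termination assumed, not shown above
            return
        yield batch


# The last batch may be shorter than `size`:
assert list(_batch(range(5), 2)) == [[0, 1], [2, 3], [4]]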
@@ -167,12 +176,12 @@ class MultiSelectChain(Chain):
         question = inputs["question"]
         columns = inputs.get("columns", None)
 
-        selected = []
+        selected: List[Mapping[str, Any]] = []
         # TODO(): Balance choices into equal batches with constraint dependent
         # on context window and prompt
         max_choices = 30
 
-        for choice_batch in batch(choices, max_choices):
+        for choice_batch in _batch(choices, max_choices):
             records_with_ids = [
                 {**record, "id": idx} for idx, record in enumerate(choice_batch)
             ]
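Each batch re-tags its records with batch-local ids starting at 0, which is what makes the later `0 <= idx < len(choice_batch)` bounds check sufficient. A small illustration with hypothetical records:

choice_batch = [{"name": "alpha"}, {"name": "beta"}]
records_with_ids = [
    {**record, "id": idx} for idx, record in enumerate(choice_batch)
]
# -> [{"name": "alpha", "id": 0}, {"name": "beta", "id": 1}]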
@@ -185,7 +194,7 @@ class MultiSelectChain(Chain):
             self.llm_chain.predict_and_parse(
                 records=records_str,
                 question=question,
-                callbacks=run_manager.get_child(),
+                callbacks=run_manager.get_child() if run_manager else None,
             ),
         )
         valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
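This guard, repeated in the async hunk below and in DocReadingChain, exists because LangChain passes `run_manager` to `_call` as an optional argument; dereferencing it unconditionally raises `AttributeError` whenever the chain is invoked without callbacks. A minimal sketch of the pattern, with the surrounding chain omitted:

from typing import Any, Dict, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun


def _call(
    self: Any,
    inputs: Dict[str, Any],
    run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    # run_manager defaults to None when no callbacks were supplied,
    # so guard the dereference before fanning out to child runs.
    callbacks = run_manager.get_child() if run_manager else None
    ...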
@@ -204,12 +213,12 @@ class MultiSelectChain(Chain):
         question = inputs["question"]
         columns = inputs.get("columns", None)
 
-        selected = []
+        selected: List[Mapping[str, Any]] = []
         # TODO(): Balance choices into equal batches with constraint dependent
         # on context window and prompt
         max_choices = 30
 
-        for choice_batch in batch(choices, max_choices):
+        for choice_batch in _batch(choices, max_choices):
             records_with_ids = [
                 {**record, "id": idx} for idx, record in enumerate(choice_batch)
             ]
@@ -222,7 +231,7 @@ class MultiSelectChain(Chain):
             await self.llm_chain.apredict_and_parse(
                 records=records_str,
                 question=question,
-                callbacks=run_manager.get_child(),
+                callbacks=run_manager.get_child() if run_manager else None,
             ),
         )
         valid_indexes = [idx for idx in indexes if 0 <= idx < len(choice_batch)]
@@ -36,7 +36,8 @@ class Research(Chain):
 
     Limitations:
     * This research chain only implements a single hop at the moment; i.e.,
-      it goes from the questions to a list of URLs to documents to compiling answers.
+      it goes from the questions to a list of URLs to documents to compiling
+      answers.
     * The reader chain needs to match the task. For example, if using a QA refine
       chain, a task of collecting a list of entries from a long document will
       fail because the QA refine chain is not designed to handle such a task.
@@ -145,7 +146,8 @@ class Research(Chain):
             top_k_per_search: The number of documents to return per search.
             max_concurrency: The maximum number of concurrent reads.
             max_num_pages_per_doc: The maximum number of pages to read per document.
-            text_splitter: The text splitter to use to split the document into smaller chunks.
+            text_splitter: The text splitter to use to split the document into
+                smaller chunks.
             download_handler: The download handler to use to download the documents.
                 Provide either a download handler or the name of a
                 download handler.
@@ -70,7 +70,7 @@ class PlaywrightDownloadHandler(DownloadHandler):
         """
         self.timeout = timeout
 
-    def download(self, urls: Sequence[str]) -> List[Blob]:
+    def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
         """Download list of urls synchronously."""
         return asyncio.run(self.adownload(urls))
 
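The `List[Blob]` to `List[MaybeBlob]` change here and in the hunks below suggests a download is now allowed to fail per-URL. The alias itself is not defined anywhere in this diff; a plausible definition, offered only as an assumption:

from typing import Optional

from langchain.document_loaders.blob_loaders import Blob

# Assumed alias: each slot is either a fetched Blob or None for a failed URL.
MaybeBlob = Optional[Blob]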
@@ -87,7 +87,7 @@ class PlaywrightDownloadHandler(DownloadHandler):
             html_content = None
         return html_content
 
-    async def adownload(self, urls: Sequence[str]) -> List[Blob]:
+    async def adownload(self, urls: Sequence[str]) -> List[MaybeBlob]:
         """Download a batch of URLs asynchronously using playwright.
 
         Args:
@@ -112,7 +112,7 @@ class RequestsDownloadHandler(DownloadHandler):
         """Initialize the requests download handler."""
         self.web_downloader = web_downloader or WebBaseLoader(web_path=[])
 
-    def download(self, urls: Sequence[str]) -> str:
+    def download(self, urls: Sequence[str]) -> List[MaybeBlob]:
         """Download a batch of URLs synchronously."""
         return asyncio.run(self.adownload(urls))
 
@@ -160,7 +160,7 @@ class AutoDownloadHandler(DownloadHandler):
         must_redownload = [
             (idx, url)
             for idx, (url, blob) in enumerate(zip(urls, blobs))
-            if _is_javascript_required(blob.as_string())
+            if blob is not None and _is_javascript_required(blob.as_string())
         ]
         if must_redownload:
             indexes, urls_to_redownload = zip(*must_redownload)
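With downloads typed as `MaybeBlob`, a failed fetch leaves `None` in the list, so the filter must skip those slots before calling `.as_string()`; `zip(*must_redownload)` then splits the surviving `(idx, url)` pairs into parallel tuples. A self-contained illustration with hypothetical data and a stand-in predicate:

# Hypothetical data: the second download failed and came back as None.
urls = ["https://a.example", "https://b.example", "https://c.example"]
blobs = ["<html>static</html>", None, "<html>needs js</html>"]

must_redownload = [
    (idx, url)
    for idx, (url, blob) in enumerate(zip(urls, blobs))
    if blob is not None and "needs js" in blob  # stand-in for _is_javascript_required
]
if must_redownload:
    indexes, urls_to_redownload = zip(*must_redownload)
    # indexes == (2,), urls_to_redownload == ("https://c.example",)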
@@ -59,7 +59,7 @@ class DocReadingChain(Chain):
 
         response = self.chain(
             {"input_documents": _sub_docs, "question": question},
-            callbacks=run_manager.get_child(),
+            callbacks=run_manager.get_child() if run_manager else None,
         )
         summary_doc = Document(
             page_content=response["output_text"],
@@ -83,7 +83,7 @@ class DocReadingChain(Chain):
 
         results = await self.chain.acall(
             {"input_documents": _sub_docs, "question": question},
-            callbacks=run_manager.get_child(),
+            callbacks=run_manager.get_child() if run_manager else None,
         )
         summary_doc = Document(
             page_content=results["output_text"],
@@ -167,7 +167,7 @@ async def _arun_searches(
         a list of unique search results
     """
     wrapper = serpapi.SerpAPIWrapper()
-    tasks = [wrapper.results(query) for query in queries]
+    tasks = [wrapper.aresults(query) for query in queries]
     results = await asyncio.gather(*tasks)
 
     finalized_results = []
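The removed line built the task list with the synchronous `results` call, which executes each search serially while the list is being built and hands `asyncio.gather` plain results rather than awaitables; `aresults` is `SerpAPIWrapper`'s async counterpart. A minimal sketch of the corrected pattern:

import asyncio

from langchain.utilities.serpapi import SerpAPIWrapper


async def run_searches(queries: list) -> list:
    wrapper = SerpAPIWrapper()
    # aresults returns a coroutine per query; gather runs them concurrently.
    tasks = [wrapper.aresults(query) for query in queries]
    return await asyncio.gather(*tasks)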
@@ -218,7 +218,8 @@ class GenericSearcher(Chain):
 
     1. Breaking a complex question into a series of simpler queries using an LLM.
     2. Running the queries against a search engine.
-    3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf or other models).
+    3. Selecting the most relevant urls using an LLM (can be replaced with tf-idf
+       or other models).
 
     This chain is not meant to be used for questions requiring multiple hops to answer.
@@ -236,7 +237,7 @@ class GenericSearcher(Chain):
     """
 
     query_generator: LLMChain
-    """An LLM that is used to break down a complex question into a list of simpler queries."""
+    """An LLM used to break down a complex question into a list of simpler queries."""
     link_selection_model: Chain
     """An LLM that is used to select the most relevant urls from the search results."""
     top_k_per_search: int = -1