From fc3c2c4406aed64f4183d4a08ad43822d358c45d Mon Sep 17 00:00:00 2001
From: Jan Philipp Harries
Date: Tue, 2 May 2023 06:23:13 +0200
Subject: [PATCH] Async Support for LLMChainExtractor (new) (#3780)

@vowelparrot @hwchase17 Here is a new implementation of `acompress_documents` for `LLMChainExtractor`, without changes to the sync version, as you suggested in #3587 / [Async Support for LLMChainExtractor](https://github.com/hwchase17/langchain/pull/3587). I created a new PR to avoid cluttering the history with reverted commits; I hope that is the right way. Happy for any improvements/suggestions.

(PS: I also tried an alternative implementation with a nested helper function, like

```python
async def acompress_documents_old(
    self, documents: Sequence[Document], query: str
) -> Sequence[Document]:
    """Compress page content of raw documents."""

    async def _compress_concurrently(doc: Document) -> Document:
        _input = self.get_input(query, doc)
        output = await self.llm_chain.apredict_and_parse(**_input)
        return Document(page_content=output, metadata=doc.metadata)

    outputs = await asyncio.gather(
        *[_compress_concurrently(doc) for doc in documents]
    )
    compressed_docs = list(filter(lambda x: len(x.page_content) > 0, outputs))
    return compressed_docs
```

But in the end I found the committed version more readable and more "canonical" - hope you agree.)
---
 .../document_compressors/chain_extract.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py
index 71b4bc13d55..9b7947c4131 100644
--- a/langchain/retrievers/document_compressors/chain_extract.py
+++ b/langchain/retrievers/document_compressors/chain_extract.py
@@ -1,6 +1,7 @@
 """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
 from __future__ import annotations
 
+import asyncio
 from typing import Any, Callable, Dict, Optional, Sequence
 
 from langchain import LLMChain, PromptTemplate
@@ -62,7 +63,21 @@ class LLMChainExtractor(BaseDocumentCompressor):
     async def acompress_documents(
         self, documents: Sequence[Document], query: str
     ) -> Sequence[Document]:
-        raise NotImplementedError
+        """Compress page content of raw documents asynchronously."""
+        outputs = await asyncio.gather(
+            *[
+                self.llm_chain.apredict_and_parse(**self.get_input(query, doc))
+                for doc in documents
+            ]
+        )
+        compressed_docs = []
+        for i, doc in enumerate(documents):
+            if len(outputs[i]) == 0:
+                continue
+            compressed_docs.append(
+                Document(page_content=outputs[i], metadata=doc.metadata)
+            )
+        return compressed_docs
 
     @classmethod
     def from_llm(
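
For anyone who wants to try the new method, here is a minimal usage sketch (not part of the patch): it assumes `ChatOpenAI` as the underlying LLM and some toy documents, and it requires `OPENAI_API_KEY` to be set in the environment; any LLM accepted by `LLMChainExtractor.from_llm` should work the same way.

```python
import asyncio

from langchain.chat_models import ChatOpenAI  # assumption: any supported LLM works
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document


async def main() -> None:
    # Build the compressor exactly as for the sync path.
    compressor = LLMChainExtractor.from_llm(ChatOpenAI(temperature=0))
    docs = [
        Document(page_content="LangChain added async support via asyncio."),
        Document(page_content="An unrelated note about cooking pasta."),
    ]
    # All documents are sent to the LLM concurrently via asyncio.gather;
    # documents whose extraction comes back empty are dropped.
    compressed = await compressor.acompress_documents(docs, query="async support")
    for doc in compressed:
        print(doc.page_content)


asyncio.run(main())
```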