diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py
index 8792a7036e7..c9462fb9b9f 100644
--- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py
+++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py
@@ -5,6 +5,7 @@ from langchain_core.callbacks.manager import Callbacks
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.runnables.config import RunnableConfig
 
 from langchain.chains import LLMChain
 from langchain.output_parsers.boolean import BooleanOutputParser
@@ -45,14 +46,49 @@ class LLMChainFilter(BaseDocumentCompressor):
     ) -> Sequence[Document]:
         """Filter down documents based on their relevance to the query."""
         filtered_docs = []
-        for doc in documents:
-            _input = self.get_input(query, doc)
-            output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks})
+
+        config = RunnableConfig(callbacks=callbacks)
+        outputs = zip(
+            self.llm_chain.batch(
+                [self.get_input(query, doc) for doc in documents], config=config
+            ),
+            documents,
+        )
+
+        for output_dict, doc in outputs:
+            include_doc = None
             output = output_dict[self.llm_chain.output_key]
             if self.llm_chain.prompt.output_parser is not None:
                 include_doc = self.llm_chain.prompt.output_parser.parse(output)
             if include_doc:
                 filtered_docs.append(doc)
+
+        return filtered_docs
+
+    async def acompress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs = []
+
+        config = RunnableConfig(callbacks=callbacks)
+        outputs = zip(
+            await self.llm_chain.abatch(
+                [self.get_input(query, doc) for doc in documents], config=config
+            ),
+            documents,
+        )
+        for output_dict, doc in outputs:
+            include_doc = None
+            output = output_dict[self.llm_chain.output_key]
+            if self.llm_chain.prompt.output_parser is not None:
+                include_doc = self.llm_chain.prompt.output_parser.parse(output)
+            if include_doc:
+                filtered_docs.append(doc)
+        return filtered_docs
 
     @classmethod