From 0908b01cb2c5f1a2401fffe821ada2b4ba956830 Mon Sep 17 00:00:00 2001 From: Philippe PRADOS Date: Tue, 11 Jun 2024 16:55:40 +0200 Subject: [PATCH] langchain[minor]: Add native async implementation to LLMFilter, add concurrency to both sync and async paths (#22739) Thank you for contributing to LangChain! - [ ] **PR title**: "langchain: Fix chain_filter.py to be compatible with async" - [ ] **PR message**: - **Description:** chain_filter is not compatible with async. - **Twitter handle:** pprados - [X ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Signed-off-by: zhangwangda Co-authored-by: Prakul Co-authored-by: Lei Zhang Co-authored-by: Gin Co-authored-by: wangda <38549158+daziz@users.noreply.github.com> Co-authored-by: Max Mulatz --- .../document_compressors/chain_filter.py | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py index 8792a7036e7..c9462fb9b9f 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py @@ -5,6 +5,7 @@ from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate, PromptTemplate +from langchain_core.runnables.config import RunnableConfig from langchain.chains import LLMChain from langchain.output_parsers.boolean import BooleanOutputParser @@ -45,14 +46,49 @@ class LLMChainFilter(BaseDocumentCompressor): ) -> Sequence[Document]: """Filter down documents based on their relevance to the query.""" filtered_docs = [] - for doc in documents: - _input = self.get_input(query, doc) - output_dict = self.llm_chain.invoke(_input, config={"callbacks": callbacks}) + + config = RunnableConfig(callbacks=callbacks) + outputs = zip( + self.llm_chain.batch( + [self.get_input(query, doc) for doc in documents], config=config + ), + documents, + ) + + for output_dict, doc in outputs: + include_doc = None output = output_dict[self.llm_chain.output_key] if self.llm_chain.prompt.output_parser is not None: include_doc = self.llm_chain.prompt.output_parser.parse(output) if include_doc: filtered_docs.append(doc) + + return filtered_docs + + async def acompress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """Filter down documents based on their relevance to the query.""" + filtered_docs = [] + + config = RunnableConfig(callbacks=callbacks) + outputs = zip( + await self.llm_chain.abatch( + [self.get_input(query, doc) for doc in documents], config=config + ), + documents, + ) + for output_dict, doc in outputs: + include_doc = None + output = output_dict[self.llm_chain.output_key] + if self.llm_chain.prompt.output_parser is not None: + include_doc = self.llm_chain.prompt.output_parser.parse(output) + if include_doc: + filtered_docs.append(doc) + return filtered_docs @classmethod