From 750edfb440c5284589a3bfa905f549ac7ac08e10 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 16 Dec 2022 06:25:29 -0800 Subject: [PATCH] add optional collapse prompt (#358) --- langchain/chains/combine_documents/map_reduce.py | 16 +++++++++++++--- langchain/chains/qa_with_sources/__init__.py | 12 +++++++++++- langchain/chains/question_answering/__init__.py | 11 ++++++++++- langchain/chains/summarize/__init__.py | 11 ++++++++++- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index ff74bc4f858..8653ef1a7f8 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator @@ -56,9 +56,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): """Combining documents by mapping a chain over them, then combining results.""" llm_chain: LLMChain - """Chain to apply to each document individually..""" + """Chain to apply to each document individually.""" combine_document_chain: BaseCombineDocumentsChain """Chain to use to combine results of applying llm_chain to documents.""" + collapse_document_chain: Optional[BaseCombineDocumentsChain] = None + """Chain to use to collapse intermediary results if needed. + If None, will use the combine_document_chain.""" document_variable_name: str """The variable name in the llm_chain to put the documents in. If only one variable in the llm_chain, this need not be provided.""" @@ -90,6 +93,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): ) return values + @property + def _collapse_chain(self) -> BaseCombineDocumentsChain: + if self.collapse_document_chain is not None: + return self.collapse_document_chain + else: + return self.combine_document_chain + def combine_docs( self, docs: List[Document], token_max: int = 3000, **kwargs: Any ) -> str: @@ -117,7 +127,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): result_docs = [] for docs in new_result_doc_list: new_doc = _collapse_docs( - docs, self.combine_document_chain.combine_docs, **kwargs + docs, self._collapse_chain.combine_docs, **kwargs ) result_docs.append(new_doc) num_tokens = self.combine_document_chain.prompt_length( diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 159abd7610d..82d93f50c80 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -1,5 +1,5 @@ """Load question answering with sources chains.""" -from typing import Any, Mapping, Protocol +from typing import Any, Mapping, Optional, Protocol from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -44,6 +44,7 @@ def _load_map_reduce_chain( document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, combine_document_variable_name: str = "summaries", map_reduce_document_variable_name: str = "context", + collapse_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: map_chain = LLMChain(llm=llm, prompt=question_prompt) @@ -53,10 +54,19 @@ def _load_map_reduce_chain( document_variable_name=combine_document_variable_name, document_prompt=document_prompt, ) + if collapse_prompt is None: + collapse_chain = None + else: + collapse_chain = StuffDocumentsChain( + llm_chain=LLMChain(llm=llm, prompt=collapse_prompt), + document_variable_name=combine_document_variable_name, + document_prompt=document_prompt, + ) return MapReduceDocumentsChain( llm_chain=map_chain, combine_document_chain=combine_document_chain, document_variable_name=map_reduce_document_variable_name, + collapse_document_chain=collapse_chain, **kwargs, ) diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 1591054caa5..9883e06840d 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -1,5 +1,5 @@ """Load question answering chains.""" -from typing import Any, Mapping, Protocol +from typing import Any, Mapping, Optional, Protocol from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -41,6 +41,7 @@ def _load_map_reduce_chain( combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, combine_document_variable_name: str = "summaries", map_reduce_document_variable_name: str = "context", + collapse_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: map_chain = LLMChain(llm=llm, prompt=question_prompt) @@ -49,10 +50,18 @@ def _load_map_reduce_chain( combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name ) + if collapse_prompt is None: + collapse_chain = None + else: + collapse_chain = StuffDocumentsChain( + llm_chain=LLMChain(llm=llm, prompt=collapse_prompt), + document_variable_name=combine_document_variable_name, + ) return MapReduceDocumentsChain( llm_chain=map_chain, combine_document_chain=combine_document_chain, document_variable_name=map_reduce_document_variable_name, + collapse_document_chain=collapse_chain, **kwargs, ) diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index ab90c70b572..e613ff8dbea 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -1,5 +1,5 @@ """Load summarizing chains.""" -from typing import Any, Mapping, Protocol +from typing import Any, Mapping, Optional, Protocol from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -37,6 +37,7 @@ def _load_map_reduce_chain( combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_document_variable_name: str = "text", map_reduce_document_variable_name: str = "text", + collapse_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: map_chain = LLMChain(llm=llm, prompt=map_prompt) @@ -45,10 +46,18 @@ def _load_map_reduce_chain( combine_document_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name ) + if collapse_prompt is None: + collapse_chain = None + else: + collapse_chain = StuffDocumentsChain( + llm_chain=LLMChain(llm=llm, prompt=collapse_prompt), + document_variable_name=combine_document_variable_name, + ) return MapReduceDocumentsChain( llm_chain=map_chain, combine_document_chain=combine_document_chain, document_variable_name=map_reduce_document_variable_name, + collapse_document_chain=collapse_chain, **kwargs, )