diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py
index 00e3efd02e0..a55c067d168 100644
--- a/libs/langchain/langchain/chains/combine_documents/reduce.py
+++ b/libs/langchain/langchain/chains/combine_documents/reduce.py
@@ -200,6 +200,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     """The maximum number of tokens to group documents into. For example, if
     set to 3000 then documents will be grouped into chunks of no greater than
     3000 tokens before trying to combine them into a smaller chunk."""
+    collapse_max_retries: Optional[int] = None
+    """The maximum number of retries to collapse documents to fit token_max.
+    If None, it will keep trying to collapse documents to fit token_max.
+    Otherwise, it will raise an error once the maximum number of retries is reached."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -289,6 +293,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
             )
 
         _token_max = token_max or self.token_max
+        retries: int = 0
         while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = split_list_of_docs(
                 result_docs, length_func, _token_max, **kwargs
@@ -298,6 +303,12 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
                 new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
                 result_docs.append(new_doc)
             num_tokens = length_func(result_docs, **kwargs)
+            retries += 1
+            if self.collapse_max_retries and retries == self.collapse_max_retries:
+                raise ValueError(
+                    f"Exceeded {self.collapse_max_retries} attempts to "
+                    f"collapse documents to {_token_max} tokens."
+                )
         return result_docs, {}
 
     async def _acollapse(
@@ -317,6 +328,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
             )
 
         _token_max = token_max or self.token_max
+        retries: int = 0
         while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = split_list_of_docs(
                 result_docs, length_func, _token_max, **kwargs
@@ -326,6 +338,12 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
                 new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
                 result_docs.append(new_doc)
             num_tokens = length_func(result_docs, **kwargs)
+            retries += 1
+            if self.collapse_max_retries and retries == self.collapse_max_retries:
+                raise ValueError(
+                    f"Exceeded {self.collapse_max_retries} attempts to "
+                    f"collapse documents to {_token_max} tokens."
+                )
         return result_docs, {}
 
     @property
diff --git a/libs/langchain/langchain/chains/summarize/__init__.py b/libs/langchain/langchain/chains/summarize/__init__.py
index 9bc8b8118bd..3f6068a0f26 100644
--- a/libs/langchain/langchain/chains/summarize/__init__.py
+++ b/libs/langchain/langchain/chains/summarize/__init__.py
@@ -52,6 +52,8 @@ def _load_map_reduce_chain(
     verbose: Optional[bool] = None,
     token_max: int = 3000,
     callbacks: Callbacks = None,
+    *,
+    collapse_max_retries: Optional[int] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(
@@ -92,6 +94,7 @@ def _load_map_reduce_chain(
         token_max=token_max,
         verbose=verbose,
         callbacks=callbacks,
+        collapse_max_retries=collapse_max_retries,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
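
For context, a minimal usage sketch assuming the patch above is applied: `collapse_max_retries` is forwarded through `load_summarize_chain`'s `**kwargs` into `ReduceDocumentsChain`. The model class (`ChatOpenAI`) and the sample documents are illustrative stand-ins, not part of this change.

```python
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI  # illustrative model choice
from langchain.schema import Document

llm = ChatOpenAI(temperature=0)

# token_max caps each collapse group at ~3000 tokens; collapse_max_retries
# bounds how many collapse passes are attempted before giving up.
chain = load_summarize_chain(
    llm,
    chain_type="map_reduce",
    token_max=3000,
    collapse_max_retries=2,
)

docs = [Document(page_content=text) for text in ["...", "..."]]  # your inputs
try:
    summary = chain.run(docs)
except ValueError as err:
    # Raised once the collapse loop has run collapse_max_retries passes.
    print(err)
```

With `collapse_max_retries` set, the collapse loop raises a `ValueError` instead of looping indefinitely when the intermediate summaries never shrink below `token_max`; leaving it as `None` preserves the previous unbounded behavior.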