From 8410c6a747b4ba99061ee98456efef238d68a6f5 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Wed, 5 Jul 2023 12:09:25 -0400
Subject: [PATCH] add token max parameter (#7204)

---
 .../chains/combine_documents/map_reduce.py   |  6 +++---
 langchain/chains/combine_documents/reduce.py | 22 ++++++++++++-------
 langchain/chains/qa_with_sources/loading.py  |  3 +++
 .../chains/question_answering/__init__.py    |  3 +++
 langchain/chains/summarize/__init__.py       |  3 +++
 5 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py
index e4dbec3e858..9afed3e7e80 100644
--- a/langchain/chains/combine_documents/map_reduce.py
+++ b/langchain/chains/combine_documents/map_reduce.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import Extra, root_validator
 
@@ -198,7 +198,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
@@ -229,7 +229,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
diff --git a/langchain/chains/combine_documents/reduce.py b/langchain/chains/combine_documents/reduce.py
index 9458c394910..7ecde047cbb 100644
--- a/langchain/chains/combine_documents/reduce.py
+++ b/langchain/chains/combine_documents/reduce.py
@@ -152,6 +152,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     """Chain to use to collapse documents if needed until they can all fit.
     If None, will use the combine_documents_chain.
     This is typically a StuffDocumentsChain."""
+    token_max: int = 3000
+    """The maximum number of tokens to group documents into. For example, if
+    set to 3000 then documents will be grouped into chunks of no greater than
+    3000 tokens before trying to combine them into a smaller chunk."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -169,7 +173,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
@@ -198,7 +202,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
@@ -227,7 +231,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def _collapse(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[List[Document], dict]:
@@ -240,9 +244,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
             input_documents=docs, callbacks=callbacks, **kwargs
         )
 
-        while num_tokens is not None and num_tokens > token_max:
+        _token_max = token_max or self.token_max
+        while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
+                result_docs, length_func, _token_max, **kwargs
             )
             result_docs = []
             for docs in new_result_doc_list:
@@ -254,7 +259,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     async def _acollapse(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[List[Document], dict]:
@@ -267,9 +272,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
             input_documents=docs, callbacks=callbacks, **kwargs
         )
 
-        while num_tokens is not None and num_tokens > token_max:
+        _token_max = token_max or self.token_max
+        while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
+                result_docs, length_func, _token_max, **kwargs
             )
             result_docs = []
             for docs in new_result_doc_list:
diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py
index 9a5058938c7..f8f8ac5f7e5 100644
--- a/langchain/chains/qa_with_sources/loading.py
+++ b/langchain/chains/qa_with_sources/loading.py
@@ -79,6 +79,7 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLanguageModel] = None,
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@@ -111,6 +112,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py
index 9fa2d2bffd7..8c2ed5d6530 100644
--- a/langchain/chains/question_answering/__init__.py
+++ b/langchain/chains/question_answering/__init__.py
@@ -99,6 +99,7 @@ def _load_map_reduce_chain(
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     callbacks: Callbacks = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     _question_prompt = (
@@ -154,6 +155,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py
index 211bf701d98..5645b73fc52 100644
--- a/langchain/chains/summarize/__init__.py
+++ b/langchain/chains/summarize/__init__.py
@@ -48,6 +48,7 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLanguageModel] = None,
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
@@ -79,6 +80,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
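
What this patch enables, as a minimal usage sketch (not part of the patch): a token_max passed to a loader such as load_summarize_chain is now forwarded to the underlying ReduceDocumentsChain. The OpenAI model and the document contents below are illustrative assumptions, not anything this commit prescribes.

# Usage sketch (not part of the patch); model and docs are illustrative.
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

# token_max now reaches ReduceDocumentsChain: groups of mapped summaries
# are collapsed until each fits within ~1000 tokens before the final
# combine step.
chain = load_summarize_chain(llm, chain_type="map_reduce", token_max=1000)

docs = [Document(page_content="...")]  # illustrative documents
summary = chain.run(docs)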
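
The per-call argument on combine_docs and friends still exists; it now defaults to None and falls back to the chain's field via _token_max = token_max or self.token_max. A sketch of that interaction, again with an illustrative prompt, model, and documents:

# Override sketch (not part of the patch); prompt, model, and docs are
# illustrative assumptions.
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Summarize the following:\n\n{context}")
stuff_chain = StuffDocumentsChain(
    llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=prompt),
    document_variable_name="context",
)

# token_max is now a constructor-time field on the chain...
reduce_chain = ReduceDocumentsChain(
    combine_documents_chain=stuff_chain,
    token_max=2000,
)

docs = [Document(page_content="...")]  # illustrative documents
# ...and the per-call kwarg overrides it; omitting it (None) falls back
# to the field through `token_max or self.token_max`.
output, _ = reduce_chain.combine_docs(docs, token_max=500)  # uses 500
output, _ = reduce_chain.combine_docs(docs)                 # uses 2000

One side effect of the "or" fallback: an explicit token_max=0 is falsy and silently falls back to the field's value. A "token_max if token_max is not None else self.token_max" check would distinguish the two, though a zero token budget is not a meaningful setting here.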