diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py index 4b684357f5d..e6f05f0327e 100644 --- a/libs/langchain/langchain/chains/combine_documents/reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/reduce.py @@ -27,7 +27,11 @@ class AsyncCombineDocsProtocol(Protocol): def split_list_of_docs( - docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any + docs: list[Document], + length_func: Callable, + token_max: int, + acum_length: bool = False, + **kwargs: Any, ) -> list[list[Document]]: """Split Documents into subsets that each meet a cumulative length constraint. @@ -35,6 +39,17 @@ def split_list_of_docs( docs: The full list of Documents. length_func: Function for computing the cumulative length of a set of Documents. token_max: The maximum cumulative length of any subset of Documents. + acum_length: If True, the length_func is accumulated over the documents, + minimizing calls. + This implies that: + ``` + length_func([a,..,n]) = + length_func([a]) + ... + length_func([n]) + length_func([]) + ``` + The last term (with empty list) might be used to adjust a border condition. + (e.g. it could remove some tokens from the last document). + This optimizes computation time from O(n^2) to O(n). + Default is False (backwards compatible). **kwargs: Arbitrary additional keyword params to pass to each call of the length_func. 
@@ -43,9 +58,15 @@ def split_list_of_docs( """ new_result_doc_list = [] _sub_result_docs = [] + _num_tokens = length_func([], **kwargs) if acum_length else 0 + for doc in docs: _sub_result_docs.append(doc) - _num_tokens = length_func(_sub_result_docs, **kwargs) + if acum_length: + _partial_tokens = length_func([doc], **kwargs) + _num_tokens += _partial_tokens + else: + _num_tokens = length_func(_sub_result_docs, **kwargs) if _num_tokens > token_max: if len(_sub_result_docs) == 1: raise ValueError( @@ -54,6 +75,9 @@ def split_list_of_docs( ) new_result_doc_list.append(_sub_result_docs[:-1]) _sub_result_docs = _sub_result_docs[-1:] + _num_tokens = ( + (length_func([], **kwargs) + _partial_tokens) if acum_length else 0 + ) new_result_doc_list.append(_sub_result_docs) return new_result_doc_list diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py index 655a13445bc..14ba5d1387c 100644 --- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py +++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py @@ -48,7 +48,8 @@ def test__split_list_double_doc() -> None: assert doc_list == [docs] -def test__split_list_works_correctly() -> None: +@pytest.mark.parametrize("acum_length", [False, True]) +def test__split_list_works_correctly(acum_length: bool) -> None: """Test splitting works correctly.""" docs = [ Document(page_content="foo"), @@ -58,7 +59,7 @@ def test__split_list_works_correctly() -> None: Document(page_content="bar"), Document(page_content="baz"), ] - doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10) + doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10, acum_length) expected_result = [ # Test a group of three. [