langchain: improve performance split_list_of_docs

This commit is contained in:
Adrian Panella 2025-04-13 12:38:34 -06:00
parent f005988e31
commit fb2ea169f2
2 changed files with 29 additions and 4 deletions

View File

@ -27,7 +27,11 @@ class AsyncCombineDocsProtocol(Protocol):
def split_list_of_docs(
docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
docs: list[Document],
length_func: Callable,
token_max: int,
acum_length: bool = False,
**kwargs: Any,
) -> list[list[Document]]:
"""Split Documents into subsets that each meet a cumulative length constraint.
@ -35,6 +39,17 @@ def split_list_of_docs(
docs: The full list of Documents.
length_func: Function for computing the cumulative length of a set of Documents.
token_max: The maximum cumulative length of any subset of Documents.
acum_length: If True, the length_func is accumulated over the documents,
minimizing calls.
This implies that:
```
length_fuc([a,..,n]) =
length_func([a]) + ... + length_func([n]) + length_func([])
```
The last term (with empty list) might be used to adjust a border condition.
(e.g. it could remove some tokens from the last document).
This optimizes computation time from O(n2) to O(n).
Default is False (backwards compatible).
**kwargs: Arbitrary additional keyword params to pass to each call of the
length_func.
@ -43,9 +58,15 @@ def split_list_of_docs(
"""
new_result_doc_list = []
_sub_result_docs = []
_num_tokens = length_func([], **kwargs) if acum_length else 0
for doc in docs:
_sub_result_docs.append(doc)
_num_tokens = length_func(_sub_result_docs, **kwargs)
if acum_length:
_partial_tokens = length_func([doc], **kwargs)
_num_tokens += _partial_tokens
else:
_num_tokens = length_func(_sub_result_docs, **kwargs)
if _num_tokens > token_max:
if len(_sub_result_docs) == 1:
raise ValueError(
@ -54,6 +75,9 @@ def split_list_of_docs(
)
new_result_doc_list.append(_sub_result_docs[:-1])
_sub_result_docs = _sub_result_docs[-1:]
_num_tokens = (
(length_func([], **kwargs) + _partial_tokens) if acum_length else 0
)
new_result_doc_list.append(_sub_result_docs)
return new_result_doc_list

View File

@ -48,7 +48,8 @@ def test__split_list_double_doc() -> None:
assert doc_list == [docs]
def test__split_list_works_correctly() -> None:
@pytest.mark.parametrize("acum_length", [False, True])
def test__split_list_works_correctly(acum_length: bool) -> None:
"""Test splitting works correctly."""
docs = [
Document(page_content="foo"),
@ -58,7 +59,7 @@ def test__split_list_works_correctly() -> None:
Document(page_content="bar"),
Document(page_content="baz"),
]
doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10)
doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10, acum_length)
expected_result = [
# Test a group of three.
[