mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 07:07:34 +00:00
langchain: improve performance split_list_of_docs
This commit is contained in:
parent
f005988e31
commit
fb2ea169f2
@ -27,7 +27,11 @@ class AsyncCombineDocsProtocol(Protocol):
|
||||
|
||||
|
||||
def split_list_of_docs(
|
||||
docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
|
||||
docs: list[Document],
|
||||
length_func: Callable,
|
||||
token_max: int,
|
||||
acum_length: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[list[Document]]:
|
||||
"""Split Documents into subsets that each meet a cumulative length constraint.
|
||||
|
||||
@ -35,6 +39,17 @@ def split_list_of_docs(
|
||||
docs: The full list of Documents.
|
||||
length_func: Function for computing the cumulative length of a set of Documents.
|
||||
token_max: The maximum cumulative length of any subset of Documents.
|
||||
acum_length: If True, the length_func is accumulated over the documents,
|
||||
minimizing calls.
|
||||
This implies that:
|
||||
```
|
||||
length_fuc([a,..,n]) =
|
||||
length_func([a]) + ... + length_func([n]) + length_func([])
|
||||
```
|
||||
The last term (with empty list) might be used to adjust a border condition.
|
||||
(e.g. it could remove some tokens from the last document).
|
||||
This optimizes computation time from O(n2) to O(n).
|
||||
Default is False (backwards compatible).
|
||||
**kwargs: Arbitrary additional keyword params to pass to each call of the
|
||||
length_func.
|
||||
|
||||
@ -43,9 +58,15 @@ def split_list_of_docs(
|
||||
"""
|
||||
new_result_doc_list = []
|
||||
_sub_result_docs = []
|
||||
_num_tokens = length_func([], **kwargs) if acum_length else 0
|
||||
|
||||
for doc in docs:
|
||||
_sub_result_docs.append(doc)
|
||||
_num_tokens = length_func(_sub_result_docs, **kwargs)
|
||||
if acum_length:
|
||||
_partial_tokens = length_func([doc], **kwargs)
|
||||
_num_tokens += _partial_tokens
|
||||
else:
|
||||
_num_tokens = length_func(_sub_result_docs, **kwargs)
|
||||
if _num_tokens > token_max:
|
||||
if len(_sub_result_docs) == 1:
|
||||
raise ValueError(
|
||||
@ -54,6 +75,9 @@ def split_list_of_docs(
|
||||
)
|
||||
new_result_doc_list.append(_sub_result_docs[:-1])
|
||||
_sub_result_docs = _sub_result_docs[-1:]
|
||||
_num_tokens = (
|
||||
(length_func([], **kwargs) + _partial_tokens) if acum_length else 0
|
||||
)
|
||||
new_result_doc_list.append(_sub_result_docs)
|
||||
return new_result_doc_list
|
||||
|
||||
|
@ -48,7 +48,8 @@ def test__split_list_double_doc() -> None:
|
||||
assert doc_list == [docs]
|
||||
|
||||
|
||||
def test__split_list_works_correctly() -> None:
|
||||
@pytest.mark.parametrize("acum_length", [False, True])
|
||||
def test__split_list_works_correctly(acum_length: bool) -> None:
|
||||
"""Test splitting works correctly."""
|
||||
docs = [
|
||||
Document(page_content="foo"),
|
||||
@ -58,7 +59,7 @@ def test__split_list_works_correctly() -> None:
|
||||
Document(page_content="bar"),
|
||||
Document(page_content="baz"),
|
||||
]
|
||||
doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10)
|
||||
doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10, acum_length)
|
||||
expected_result = [
|
||||
# Test a group of three.
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user