mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
langchain: improve performance split_list_of_docs
This commit is contained in:
parent
f005988e31
commit
fb2ea169f2
@ -27,7 +27,11 @@ class AsyncCombineDocsProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
def split_list_of_docs(
|
def split_list_of_docs(
|
||||||
docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
|
docs: list[Document],
|
||||||
|
length_func: Callable,
|
||||||
|
token_max: int,
|
||||||
|
acum_length: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
) -> list[list[Document]]:
|
) -> list[list[Document]]:
|
||||||
"""Split Documents into subsets that each meet a cumulative length constraint.
|
"""Split Documents into subsets that each meet a cumulative length constraint.
|
||||||
|
|
||||||
@ -35,6 +39,17 @@ def split_list_of_docs(
|
|||||||
docs: The full list of Documents.
|
docs: The full list of Documents.
|
||||||
length_func: Function for computing the cumulative length of a set of Documents.
|
length_func: Function for computing the cumulative length of a set of Documents.
|
||||||
token_max: The maximum cumulative length of any subset of Documents.
|
token_max: The maximum cumulative length of any subset of Documents.
|
||||||
|
acum_length: If True, the length_func is accumulated over the documents,
|
||||||
|
minimizing calls.
|
||||||
|
This implies that:
|
||||||
|
```
|
||||||
|
length_func([a,..,n]) =
|
||||||
|
length_func([a]) + ... + length_func([n]) + length_func([])
|
||||||
|
```
|
||||||
|
The last term (with an empty list) can be used to account for a boundary condition.
|
||||||
|
(e.g. it could remove some tokens from the last document).
|
||||||
|
This optimizes computation time from O(n^2) to O(n).
|
||||||
|
Default is False (backwards compatible).
|
||||||
**kwargs: Arbitrary additional keyword params to pass to each call of the
|
**kwargs: Arbitrary additional keyword params to pass to each call of the
|
||||||
length_func.
|
length_func.
|
||||||
|
|
||||||
@ -43,9 +58,15 @@ def split_list_of_docs(
|
|||||||
"""
|
"""
|
||||||
new_result_doc_list = []
|
new_result_doc_list = []
|
||||||
_sub_result_docs = []
|
_sub_result_docs = []
|
||||||
|
_num_tokens = length_func([], **kwargs) if acum_length else 0
|
||||||
|
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
_sub_result_docs.append(doc)
|
_sub_result_docs.append(doc)
|
||||||
_num_tokens = length_func(_sub_result_docs, **kwargs)
|
if acum_length:
|
||||||
|
_partial_tokens = length_func([doc], **kwargs)
|
||||||
|
_num_tokens += _partial_tokens
|
||||||
|
else:
|
||||||
|
_num_tokens = length_func(_sub_result_docs, **kwargs)
|
||||||
if _num_tokens > token_max:
|
if _num_tokens > token_max:
|
||||||
if len(_sub_result_docs) == 1:
|
if len(_sub_result_docs) == 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -54,6 +75,9 @@ def split_list_of_docs(
|
|||||||
)
|
)
|
||||||
new_result_doc_list.append(_sub_result_docs[:-1])
|
new_result_doc_list.append(_sub_result_docs[:-1])
|
||||||
_sub_result_docs = _sub_result_docs[-1:]
|
_sub_result_docs = _sub_result_docs[-1:]
|
||||||
|
_num_tokens = (
|
||||||
|
(length_func([], **kwargs) + _partial_tokens) if acum_length else 0
|
||||||
|
)
|
||||||
new_result_doc_list.append(_sub_result_docs)
|
new_result_doc_list.append(_sub_result_docs)
|
||||||
return new_result_doc_list
|
return new_result_doc_list
|
||||||
|
|
||||||
|
@ -48,7 +48,8 @@ def test__split_list_double_doc() -> None:
|
|||||||
assert doc_list == [docs]
|
assert doc_list == [docs]
|
||||||
|
|
||||||
|
|
||||||
def test__split_list_works_correctly() -> None:
|
@pytest.mark.parametrize("acum_length", [False, True])
|
||||||
|
def test__split_list_works_correctly(acum_length: bool) -> None:
|
||||||
"""Test splitting works correctly."""
|
"""Test splitting works correctly."""
|
||||||
docs = [
|
docs = [
|
||||||
Document(page_content="foo"),
|
Document(page_content="foo"),
|
||||||
@ -58,7 +59,7 @@ def test__split_list_works_correctly() -> None:
|
|||||||
Document(page_content="bar"),
|
Document(page_content="bar"),
|
||||||
Document(page_content="baz"),
|
Document(page_content="baz"),
|
||||||
]
|
]
|
||||||
doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10)
|
doc_list = split_list_of_docs(docs, _fake_docs_len_func, 10, acum_length)
|
||||||
expected_result = [
|
expected_result = [
|
||||||
# Test a group of three.
|
# Test a group of three.
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user