Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-13)
# fix(text-splitters): add validation to prevent infinite loop and skip empty chunks in the token splitter (#32205)
### Description

1) Add validation requiring `tokenizer.tokens_per_chunk > tokenizer.chunk_overlap`; otherwise the window stride is zero or negative, `start_idx` never advances, and the loop runs forever.
2) Skip empty decoded chunks so the splitter no longer appends empty strings to its output.

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
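To see why the guard in (1) is needed: the loop advances by `tokens_per_chunk - chunk_overlap` tokens per iteration, so when `tokens_per_chunk <= chunk_overlap` the stride is `<= 0` and `start_idx` never reaches `len(input_ids)`. A minimal standalone sketch of the pre-fix stepping logic (simplified from the diff below; an iteration cap is added here so the sketch actually halts):

```python
# Pre-fix stride logic: with tokens_per_chunk == chunk_overlap the stride
# is 0, so start_idx stays at 0 and the real loop would never terminate.
tokens_per_chunk, chunk_overlap = 2, 2
input_ids = list(range(10))  # stand-in for tokenizer.encode(text)

start_idx = 0
steps = 0
while start_idx < len(input_ids):
    start_idx += tokens_per_chunk - chunk_overlap  # += 0: no progress
    steps += 1
    if steps > 100:  # safety cap, not in the original code
        print("no progress after 100 iterations -> infinite loop")
        break
```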
```diff
@@ -354,13 +354,19 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
     splits: list[str] = []
     input_ids = tokenizer.encode(text)
     start_idx = 0
-    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
-    chunk_ids = input_ids[start_idx:cur_idx]
+    if tokenizer.tokens_per_chunk <= tokenizer.chunk_overlap:
+        msg = "tokens_per_chunk must be greater than chunk_overlap"
+        raise ValueError(msg)
+
     while start_idx < len(input_ids):
-        splits.append(tokenizer.decode(chunk_ids))
+        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+        if not chunk_ids:
+            break
+        decoded = tokenizer.decode(chunk_ids)
+        if decoded:
+            splits.append(decoded)
         if cur_idx == len(input_ids):
             break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
-        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
-        chunk_ids = input_ids[start_idx:cur_idx]
     return splits
```
```diff
@@ -2849,6 +2849,21 @@ def test_split_text_on_tokens() -> None:
     assert output == expected_output


+def test_decode_returns_no_chunks() -> None:
+    """Test that when decode returns only empty strings, output is empty, not ['']."""
+    text = "foo bar baz 123"
+
+    tokenizer = Tokenizer(
+        chunk_overlap=3,
+        tokens_per_chunk=7,
+        decode=(lambda _: ""),
+        encode=(lambda it: [ord(c) for c in it]),
+    )
+    output = split_text_on_tokens(text=text, tokenizer=tokenizer)
+    expected_output: list[Any] = []
+    assert output == expected_output
+
+
 @pytest.mark.requires("bs4")
 @pytest.mark.requires("lxml")
 def test_section_aware_happy_path_splitting_based_on_header_1_2() -> None:
```
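The validation branch from the first hunk could be covered in the same style. A hypothetical companion test, not part of this commit, might look like (it relies only on the `ValueError` message added above):

```python
import pytest

def test_invalid_overlap_raises() -> None:
    """Hypothetical: overlap >= chunk size should fail fast, not hang."""
    tokenizer = Tokenizer(
        chunk_overlap=7,
        tokens_per_chunk=7,  # equal to overlap -> stride would be 0
        decode=(lambda ids: "".join(chr(i) for i in ids)),
        encode=(lambda it: [ord(c) for c in it]),
    )
    with pytest.raises(ValueError, match="tokens_per_chunk must be greater"):
        split_text_on_tokens(text="foo bar", tokenizer=tokenizer)
```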