fix(text-splitters): add validation to prevent infinite loop and prevent empty token splitter (#32205)

### Description
1) Add validation to prevent infinite loop condition when
```tokenizer.tokens_per_chunk > tokenizer.chunk_overlap```
2) Avoid empty decoded chunk when splitter appends tokens

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Hyunjoon Jeong
2025-09-12 05:55:32 +09:00
committed by GitHub
parent 7e5180e2fa
commit 9cc85387d1
2 changed files with 26 additions and 5 deletions

View File

@@ -354,13 +354,19 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
splits: list[str] = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
if tokenizer.tokens_per_chunk <= tokenizer.chunk_overlap:
msg = "tokens_per_chunk must be greater than chunk_overlap"
raise ValueError(msg)
while start_idx < len(input_ids):
splits.append(tokenizer.decode(chunk_ids))
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
if not chunk_ids:
break
decoded = tokenizer.decode(chunk_ids)
if decoded:
splits.append(decoded)
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return splits

View File

@@ -2849,6 +2849,21 @@ def test_split_text_on_tokens() -> None:
assert output == expected_output
def test_decode_returns_no_chunks() -> None:
"""Test that when decode returns only empty strings, output is empty, not ['']."""
text = "foo bar baz 123"
tokenizer = Tokenizer(
chunk_overlap=3,
tokens_per_chunk=7,
decode=(lambda _: ""),
encode=(lambda it: [ord(c) for c in it]),
)
output = split_text_on_tokens(text=text, tokenizer=tokenizer)
expected_output: list[Any] = []
assert output == expected_output
@pytest.mark.requires("bs4")
@pytest.mark.requires("lxml")
def test_section_aware_happy_path_splitting_based_on_header_1_2() -> None: