mirror of
https://github.com/hwchase17/langchain.git
synced 2025-11-16 05:30:23 +00:00
fix(text-splitters): add validation to prevent infinite loop and prevent empty token splitter (#32205)
### Description 1) Add validation to prevent infinite loop condition when ```tokenizer.tokens_per_chunk > tokenizer.chunk_overlap``` 2) Avoid empty decoded chunk when splitter appends tokens --------- Co-authored-by: Eugene Yurtsev <eugene@langchain.dev> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -354,13 +354,19 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
|
||||
splits: list[str] = []
|
||||
input_ids = tokenizer.encode(text)
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
if tokenizer.tokens_per_chunk <= tokenizer.chunk_overlap:
|
||||
msg = "tokens_per_chunk must be greater than chunk_overlap"
|
||||
raise ValueError(msg)
|
||||
|
||||
while start_idx < len(input_ids):
|
||||
splits.append(tokenizer.decode(chunk_ids))
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
if not chunk_ids:
|
||||
break
|
||||
decoded = tokenizer.decode(chunk_ids)
|
||||
if decoded:
|
||||
splits.append(decoded)
|
||||
if cur_idx == len(input_ids):
|
||||
break
|
||||
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
return splits
|
||||
|
||||
Reference in New Issue
Block a user