langchain/libs/text-splitters/langchain_text_splitters/nltk.py
Antonio Lanza b2102b8cc4
text-splitters: Inconsistent results with NLTKTextSplitter's add_start_index=True (#27782)
This PR closes #27781

# Problem
The current implementation of `NLTKTextSplitter` uses `sent_tokenize`. However, `sent_tokenize` drops the characters between two tokenized sentences, so the resulting chunks can no longer be mapped back to their positions in the original text. This produces inconsistent results when `add_start_index=True`, as described in issue #27781. In
particular:
```python
from nltk.tokenize import sent_tokenize

output1 = sent_tokenize("Innovation drives our success. Collaboration fosters creative solutions. Efficiency enhances data management.", language="english")
print(output1)
output2 = sent_tokenize("Innovation drives our success.        Collaboration fosters creative solutions. Efficiency enhances data management.", language="english")
print(output2)
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
```
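
Both calls return the same list even though the second input contains a run of extra spaces between the sentences: the whitespace between sentences is silently dropped, so the chunks can no longer be aligned with their original character offsets. A minimal illustration of the drift (the variable names here are just for this sketch):
```python
from nltk.tokenize import sent_tokenize

text = (
    "Innovation drives our success.        "
    "Collaboration fosters creative solutions."
)
sentences = sent_tokenize(text, language="english")

# Rejoining the tokenized sentences does not reproduce the original string:
# the run of spaces between the two sentences is gone, so character offsets
# computed against the rejoined text no longer match `text`.
rejoined = " ".join(sentences)
print(rejoined == text)          # False
print(len(rejoined), len(text))  # lengths differ, so start indices drift
```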

# Solution
This PR adds a new `use_span_tokenize` parameter. When it is enabled, sentences are produced with NLTK's `span_tokenize`, and the characters skipped between consecutive sentences are prepended to the following sentence, so the chunks can still be mapped back to the original text.
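
As a rough sketch of the idea: `span_tokenize` yields character offsets into the original text instead of stripped sentence strings. The snippet below uses `_get_punkt_tokenizer`, the same private NLTK helper the splitter imports (available in recent NLTK versions); the span values in the comment are just what this particular input produces.
```python
from nltk.tokenize import _get_punkt_tokenizer  # private helper, recent NLTK

text = (
    "Innovation drives our success.        "
    "Collaboration fosters creative solutions."
)
tokenizer = _get_punkt_tokenizer("english")

# span_tokenize yields (start, end) character offsets into the original text,
# e.g. [(0, 30), (38, 79)] here: the eight spaces live between the two spans.
spans = list(tokenizer.span_tokenize(text))

# Prepending the skipped characters to the following sentence keeps the
# concatenation of the splits identical to the original text.
splits = []
for i, (start, end) in enumerate(spans):
    prev_end = spans[i - 1][1] if i > 0 else start
    splits.append(text[prev_end:end])

assert "".join(splits) == text
```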

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Erick Friis <erickfriis@gmail.com>
2024-12-16 19:53:15 +00:00

from __future__ import annotations

from typing import Any, List

from langchain_text_splitters.base import TextSplitter


class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self,
        separator: str = "\n\n",
        language: str = "english",
        *,
        use_span_tokenize: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._language = language
        self._use_span_tokenize = use_span_tokenize
        if self._use_span_tokenize and self._separator != "":
            raise ValueError("When use_span_tokenize is True, separator should be ''")
        try:
            if self._use_span_tokenize:
                from nltk.tokenize import _get_punkt_tokenizer

                self._tokenizer = _get_punkt_tokenizer(self._language)
            else:
                from nltk.tokenize import sent_tokenize

                self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        if self._use_span_tokenize:
            spans = list(self._tokenizer.span_tokenize(text))
            splits = []
            for i, (start, end) in enumerate(spans):
                if i > 0:
                    prev_end = spans[i - 1][1]
                    sentence = text[prev_end:start] + text[start:end]
                else:
                    sentence = text[start:end]
                splits.append(sentence)
        else:
            splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
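
A usage sketch (not part of the file above): it assumes the NLTK punkt data is available, and the chunking depends on the configured `chunk_size`; the values here are arbitrary.
```python
from langchain_text_splitters import NLTKTextSplitter

# With use_span_tokenize=True the separator must be "" (see the check in __init__),
# and start_index then points at each chunk's true offset in the input string.
splitter = NLTKTextSplitter(
    separator="",
    use_span_tokenize=True,
    chunk_size=60,
    chunk_overlap=0,
    add_start_index=True,
)
docs = splitter.create_documents(
    ["Innovation drives our success.        Collaboration fosters creative solutions."]
)
for doc in docs:
    print(doc.metadata["start_index"], repr(doc.page_content))
```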