text-splitters: Inconsistent results with NLTKTextSplitter's add_start_index=True (#27782)

This PR closes #27781

# Problem
The current implementation of `NLTKTextSplitter` uses `sent_tokenize`. However, `sent_tokenize` does not preserve the characters between two tokenized sentences, so the resulting chunks can no longer be mapped back to their positions in the original text. As a result, `add_start_index=True` produces inconsistent start indices, as described in issue #27781. In particular, the two calls below return identical output even though the second input contains extra whitespace between the first two sentences; that whitespace is silently dropped, so the sentence offsets cannot be recovered:
```python
from nltk.tokenize import sent_tokenize

output1 = sent_tokenize("Innovation drives our success. Collaboration fosters creative solutions. Efficiency enhances data management.", language="english")
print(output1)
output2 = sent_tokenize("Innovation drives our success.        Collaboration fosters creative solutions. Efficiency enhances data management.", language="english")
print(output2)
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
```

# Solution
The new `use_span_tokenize` parameter switches the splitter to NLTK's `span_tokenize`, which yields character offsets for each sentence. The characters between consecutive sentences are prepended to the following split, so the chunks can still be mapped back to the original text.
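
The idea in miniature (a sketch, not the PR's code: it uses a default-constructed `PunktSentenceTokenizer`, whereas the PR obtains a pretrained tokenizer through NLTK's private `_get_punkt_tokenizer` helper):

```python
from nltk.tokenize.punkt import PunktSentenceTokenizer

text = (
    "Innovation drives our success.        "  # note the 8 trailing spaces
    "Collaboration fosters creative solutions."
)

# span_tokenize yields (start, end) offsets into the original string,
# so nothing between sentences is lost: here [(0, 30), (38, 80)].
spans = list(PunktSentenceTokenizer().span_tokenize(text))

# Re-attach each inter-sentence gap to the sentence that follows it,
# mirroring the new use_span_tokenize code path.
splits = []
for i, (start, end) in enumerate(spans):
    prev_end = spans[i - 1][1] if i > 0 else start
    splits.append(text[prev_end:end])

assert "".join(splits) == text  # the splits reconstruct the input exactly
```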

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Erick Friis <erickfriis@gmail.com>
7 changed files with 1848 additions and 27 deletions

Makefile

```diff
@@ -9,6 +9,9 @@ TEST_FILE ?= tests/unit_tests/
 test tests:
 	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
 
+integration_test integration_tests:
+	poetry run pytest tests/integration_tests/
+
 test_watch:
 	poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
```

langchain_text_splitters/nltk.py

```diff
@@ -9,23 +9,47 @@ class NLTKTextSplitter(TextSplitter):
     """Splitting text using NLTK package."""
 
     def __init__(
-        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
+        self,
+        separator: str = "\n\n",
+        language: str = "english",
+        *,
+        use_span_tokenize: bool = False,
+        **kwargs: Any,
     ) -> None:
         """Initialize the NLTK splitter."""
         super().__init__(**kwargs)
+        self._separator = separator
+        self._language = language
+        self._use_span_tokenize = use_span_tokenize
+        if self._use_span_tokenize and self._separator != "":
+            raise ValueError("When use_span_tokenize is True, separator should be ''")
         try:
-            from nltk.tokenize import sent_tokenize
+            if self._use_span_tokenize:
+                from nltk.tokenize import _get_punkt_tokenizer
 
-            self._tokenizer = sent_tokenize
+                self._tokenizer = _get_punkt_tokenizer(self._language)
+            else:
+                from nltk.tokenize import sent_tokenize
+
+                self._tokenizer = sent_tokenize
         except ImportError:
             raise ImportError(
                 "NLTK is not installed, please install it with `pip install nltk`."
             )
-        self._separator = separator
-        self._language = language
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
-        splits = self._tokenizer(text, language=self._language)
+        if self._use_span_tokenize:
+            spans = list(self._tokenizer.span_tokenize(text))
+            splits = []
+            for i, (start, end) in enumerate(spans):
+                if i > 0:
+                    prev_end = spans[i - 1][1]
+                    sentence = text[prev_end:start] + text[start:end]
+                else:
+                    sentence = text[start:end]
+                splits.append(sentence)
+        else:
+            splits = self._tokenizer(text, language=self._language)
         return self._merge_splits(splits, self._separator)
```
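
For reference, a minimal usage sketch of the new code path (the text and chunk size are illustrative; NLTK's punkt data must be available, e.g. via `nltk.download("punkt_tab")` as in the tests below):

```python
from langchain_text_splitters import NLTKTextSplitter

splitter = NLTKTextSplitter(
    separator="",            # must be "" when use_span_tokenize=True
    use_span_tokenize=True,  # split via span_tokenize to keep offsets exact
    add_start_index=True,    # record each chunk's offset in the metadata
    chunk_size=80,
    chunk_overlap=0,
)

text = (
    "Innovation drives our success.        "
    "Collaboration fosters creative solutions."
)
for doc in splitter.create_documents([text]):
    start = doc.metadata["start_index"]
    # Each chunk now slices cleanly out of the original text:
    assert doc.page_content == text[start : start + len(doc.page_content)]
```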

langchain_text_splitters/sentence_transformers.py

```diff
@@ -22,7 +22,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             from sentence_transformers import SentenceTransformer
         except ImportError:
             raise ImportError(
-                "Could not import sentence_transformer python package. "
+                "Could not import sentence_transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                 "Please install it with `pip install sentence-transformers`."
             )
```

(File diff suppressed because it is too large.)

pyproject.toml

```diff
@@ -1,5 +1,5 @@
 [build-system]
-requires = [ "poetry-core>=1.0.0",]
+requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
@@ -14,7 +14,20 @@ repository = "https://github.com/langchain-ai/langchain"
 [tool.mypy]
 disallow_untyped_defs = "True"
 
 [[tool.mypy.overrides]]
-module = [ "transformers", "sentence_transformers", "nltk.tokenize", "konlpy.tag", "bs4", "pytest", "spacy", "spacy.lang.en", "numpy",]
+module = [
+    "transformers",
+    "sentence_transformers",
+    "nltk.tokenize",
+    "konlpy.tag",
+    "bs4",
+    "pytest",
+    "spacy",
+    "spacy.lang.en",
+    "numpy",
+    "nltk",
+    "spacy.cli",
+    "torch",
+]
 ignore_missing_imports = "True"
 
 [tool.poetry.urls]
@@ -26,15 +39,18 @@ python = ">=3.9,<4.0"
 langchain-core = "^0.3.25"
 
 [tool.ruff.lint]
-select = [ "E", "F", "I", "T201", "D",]
-ignore = [ "D100",]
+select = ["E", "F", "I", "T201", "D"]
+ignore = ["D100"]
 
 [tool.coverage.run]
-omit = [ "tests/*",]
+omit = ["tests/*"]
 
 [tool.pytest.ini_options]
 addopts = "--strict-markers --strict-config --durations=5"
-markers = [ "requires: mark tests as requiring a specific library", "compile: mark placeholder test used to compile integration tests without running them",]
+markers = [
+    "requires: mark tests as requiring a specific library",
+    "compile: mark placeholder test used to compile integration tests without running them",
+]
 asyncio_mode = "auto"
 
 [tool.poetry.group.lint]
@@ -53,19 +69,17 @@ optional = true
 convention = "google"
 
 [tool.ruff.lint.per-file-ignores]
-"tests/**" = [ "D",]
+"tests/**" = ["D"]
 
 [tool.poetry.group.lint.dependencies]
 ruff = "^0.5"
 
 [tool.poetry.group.typing.dependencies]
 mypy = "^1.10"
 lxml-stubs = "^0.5.1"
 types-requests = "^2.31.0.20240218"
 tiktoken = "^0.8.0"
 
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
@@ -78,20 +92,23 @@ pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
 pytest-socket = "^0.7.0"
 
+[tool.poetry.group.test_integration]
+optional = true
+
 [tool.poetry.group.test_integration.dependencies]
+spacy = { version = "*", python = "<3.13" }
+nltk = "^3.9.1"
+transformers = "^4.47.0"
+sentence-transformers = { version = ">=2.6.0", python = "<3.13" }
 
 [tool.poetry.group.lint.dependencies.langchain-core]
 path = "../core"
 develop = true
 
 [tool.poetry.group.dev.dependencies.langchain-core]
 path = "../core"
 develop = true
 
 [tool.poetry.group.test.dependencies.langchain-core]
 path = "../core"
 develop = true
```
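
Since the new `test_integration` group is marked `optional = true`, the heavyweight integration dependencies (spacy, nltk, transformers, sentence-transformers) are only installed when the group is explicitly requested (for example with `poetry install --with test_integration`), which pairs with the `integration_tests` Makefile target added above.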

tests/unit_tests/ (NLTK and Spacy splitter tests)

```diff
@@ -1,18 +1,36 @@
 """Test text splitting functionality using NLTK and Spacy based sentence splitters."""
+from typing import Any
+
+import nltk
 import pytest
+from langchain_core.documents import Document
 
 from langchain_text_splitters.nltk import NLTKTextSplitter
 from langchain_text_splitters.spacy import SpacyTextSplitter
 
 
+def setup_module() -> None:
+    nltk.download("punkt_tab")
+
+
+@pytest.fixture()
+def spacy() -> Any:
+    try:
+        import spacy
+    except ImportError:
+        pytest.skip("Spacy not installed.")
+    spacy.cli.download("en_core_web_sm")  # type: ignore
+    return spacy
+
+
 def test_nltk_text_splitting_args() -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):
         NLTKTextSplitter(chunk_size=2, chunk_overlap=4)
 
 
-def test_spacy_text_splitting_args() -> None:
+def test_spacy_text_splitting_args(spacy: Any) -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):
         SpacyTextSplitter(chunk_size=2, chunk_overlap=4)
@@ -29,7 +47,7 @@ def test_nltk_text_splitter() -> None:
 
 @pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
-def test_spacy_text_splitter(pipeline: str) -> None:
+def test_spacy_text_splitter(pipeline: str, spacy: Any) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
@@ -40,7 +58,7 @@ def test_spacy_text_splitter(pipeline: str) -> None:
 
 @pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
-def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
+def test_spacy_text_splitter_strip_whitespace(pipeline: str, spacy: Any) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
@@ -50,3 +68,35 @@ def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
     output = splitter.split_text(text)
     expected_output = [f"This is sentence one. {separator}And this is sentence two."]
     assert output == expected_output
+
+
+def test_nltk_text_splitter_args() -> None:
+    """Test invalid arguments for NLTKTextSplitter."""
+    with pytest.raises(ValueError):
+        NLTKTextSplitter(
+            chunk_size=80,
+            chunk_overlap=0,
+            separator="\n\n",
+            use_span_tokenize=True,
+        )
+
+
+def test_nltk_text_splitter_with_add_start_index() -> None:
+    splitter = NLTKTextSplitter(
+        chunk_size=80,
+        chunk_overlap=0,
+        separator="",
+        use_span_tokenize=True,
+        add_start_index=True,
+    )
+    txt = (
+        "Innovation drives our success. "
+        "Collaboration fosters creative solutions. "
+        "Efficiency enhances data management."
+    )
+    docs = [Document(txt)]
+    chunks = splitter.split_documents(docs)
+    assert len(chunks) == 2
+    for chunk in chunks:
+        s_i = chunk.metadata["start_index"]
+        assert chunk.page_content == txt[s_i : s_i + len(chunk.page_content)]
```
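
A note on the expected chunk count in `test_nltk_text_splitter_with_add_start_index`: with `use_span_tokenize=True` the inter-sentence space is carried with each following sentence, so the three splits are 30, 43, and 37 characters long. Under `chunk_size=80` the merge step packs the first two splits (73 characters) into one chunk and the third becomes the second chunk, which is why exactly two chunks come back and every `start_index` slices the original string exactly.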

tests/integration_tests/ (splitter tests requiring optional dependencies)

```diff
@@ -1,5 +1,7 @@
 """Test text splitters that require an integration."""
+from typing import Any
+
 import pytest
 
 from langchain_text_splitters import (
@@ -11,6 +13,15 @@ from langchain_text_splitters.sentence_transformers import (
 )
 
 
+@pytest.fixture()
+def sentence_transformers() -> Any:
+    try:
+        import sentence_transformers
+    except ImportError:
+        pytest.skip("SentenceTransformers not installed.")
+    return sentence_transformers
+
+
 def test_huggingface_type_check() -> None:
     """Test that type checks are done properly on input."""
     with pytest.raises(ValueError):
@@ -52,7 +63,7 @@ def test_token_text_splitter_from_tiktoken() -> None:
     assert expected_tokenizer == actual_tokenizer
 
-def test_sentence_transformers_count_tokens() -> None:
+def test_sentence_transformers_count_tokens(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(
         model_name="sentence-transformers/paraphrase-albert-small-v2"
     )
@@ -67,7 +78,7 @@ def test_sentence_transformers_count_tokens() -> None:
     assert expected_token_count == token_count
 
-def test_sentence_transformers_split_text() -> None:
+def test_sentence_transformers_split_text(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(
         model_name="sentence-transformers/paraphrase-albert-small-v2"
     )
@@ -77,7 +88,7 @@ def test_sentence_transformers_split_text() -> None:
     assert expected_text_chunks == text_chunks
 
-def test_sentence_transformers_multiple_tokens() -> None:
+def test_sentence_transformers_multiple_tokens(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
     text = "Lorem "
```