Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-05 12:48:12 +00:00)
text-splitters: Inconsistent results with NLTKTextSplitter's add_start_index=True (#27782)
This PR closes #27781

# Problem

The current implementation of `NLTKTextSplitter` uses `sent_tokenize`. However, `sent_tokenize` drops the characters that sit between two tokenized sentences, so the resulting chunks can no longer be mapped back to their positions in the original text. This breaks `add_start_index=True`, as described in issue #27781. In particular:

```python
from nltk.tokenize import sent_tokenize

output1 = sent_tokenize(
    "Innovation drives our success. Collaboration fosters creative solutions. Efficiency enhances data management.",
    language="english",
)
print(output1)

output2 = sent_tokenize(
    "Innovation drives our success.\n\nCollaboration fosters creative solutions.\n\nEfficiency enhances data management.",
    language="english",
)
print(output2)

>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
>>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.']
```

Both inputs produce exactly the same output, so whatever characters separated the sentences are lost and the start index of each chunk cannot be recovered.

# Solution

With the new `use_span_tokenize` parameter, we use NLTK's `span_tokenize` to locate sentences by character offset and re-attach the characters between spans, so the chunks can still be mapped back to the original text.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Erick Friis <erickfriis@gmail.com>
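To illustrate the mechanism the fix relies on, here is a minimal sketch using NLTK's public `PunktSentenceTokenizer` (the splitter itself goes through a private helper): `span_tokenize` returns character offsets rather than strings, so nothing between sentences is lost.

```python
from nltk.tokenize import PunktSentenceTokenizer

text = (
    "Innovation drives our success.\n\n"
    "Collaboration fosters creative solutions.\n\n"
    "Efficiency enhances data management."
)

# span_tokenize yields (start, end) character offsets instead of plain strings,
# so the characters between sentences (here "\n\n") are never dropped and every
# sentence can be sliced straight out of the original text.
for start, end in PunktSentenceTokenizer().span_tokenize(text):
    print(start, end, repr(text[start:end]))
```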
This commit is contained in:
parent d262d41cc0 · commit b2102b8cc4
```diff
@@ -9,6 +9,9 @@ TEST_FILE ?= tests/unit_tests/
 test tests:
 	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
 
+integration_test integration_tests:
+	poetry run pytest tests/integration_tests/
+
 test_watch:
 	poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
 
```
```diff
@@ -9,23 +9,47 @@ class NLTKTextSplitter(TextSplitter):
     """Splitting text using NLTK package."""
 
     def __init__(
-        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
+        self,
+        separator: str = "\n\n",
+        language: str = "english",
+        *,
+        use_span_tokenize: bool = False,
+        **kwargs: Any,
     ) -> None:
         """Initialize the NLTK splitter."""
         super().__init__(**kwargs)
+        self._separator = separator
+        self._language = language
+        self._use_span_tokenize = use_span_tokenize
+        if self._use_span_tokenize and self._separator != "":
+            raise ValueError("When use_span_tokenize is True, separator should be ''")
         try:
-            from nltk.tokenize import sent_tokenize
+            if self._use_span_tokenize:
+                from nltk.tokenize import _get_punkt_tokenizer
 
-            self._tokenizer = sent_tokenize
+                self._tokenizer = _get_punkt_tokenizer(self._language)
+            else:
+                from nltk.tokenize import sent_tokenize
+
+                self._tokenizer = sent_tokenize
         except ImportError:
             raise ImportError(
                 "NLTK is not installed, please install it with `pip install nltk`."
             )
-        self._separator = separator
-        self._language = language
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
-        splits = self._tokenizer(text, language=self._language)
+        if self._use_span_tokenize:
+            spans = list(self._tokenizer.span_tokenize(text))
+            splits = []
+            for i, (start, end) in enumerate(spans):
+                if i > 0:
+                    prev_end = spans[i - 1][1]
+                    sentence = text[prev_end:start] + text[start:end]
+                else:
+                    sentence = text[start:end]
+                splits.append(sentence)
+        else:
+            splits = self._tokenizer(text, language=self._language)
         return self._merge_splits(splits, self._separator)
```
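For reference, a small usage sketch of the new flag, modeled on the unit test added later in this PR; the exact chunking assumes the 80-character chunk size used there.

```python
from langchain_core.documents import Document
from langchain_text_splitters import NLTKTextSplitter

# separator must be "" when use_span_tokenize=True, otherwise __init__ raises.
splitter = NLTKTextSplitter(
    chunk_size=80,
    chunk_overlap=0,
    separator="",
    use_span_tokenize=True,
    add_start_index=True,
)

text = (
    "Innovation drives our success. "
    "Collaboration fosters creative solutions. "
    "Efficiency enhances data management."
)

chunks = splitter.split_documents([Document(text)])

# With span-based tokenization, each chunk's start_index points at the exact
# offset of that chunk in the source text.
for chunk in chunks:
    start = chunk.metadata["start_index"]
    assert chunk.page_content == text[start : start + len(chunk.page_content)]
```

Running the sketch requires `nltk` and the `punkt_tab` data (`nltk.download("punkt_tab")`), as in the test setup further down.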
```diff
@@ -22,7 +22,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             from sentence_transformers import SentenceTransformer
         except ImportError:
             raise ImportError(
-                "Could not import sentence_transformer python package. "
+                "Could not import sentence_transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                 "Please install it with `pip install sentence-transformers`."
             )
```
libs/text-splitters/poetry.lock (generated, 1720 lines changed): file diff suppressed because it is too large.
```diff
@@ -1,5 +1,5 @@
 [build-system]
-requires = [ "poetry-core>=1.0.0",]
+requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
@@ -14,7 +14,20 @@ repository = "https://github.com/langchain-ai/langchain"
 [tool.mypy]
 disallow_untyped_defs = "True"
 [[tool.mypy.overrides]]
-module = [ "transformers", "sentence_transformers", "nltk.tokenize", "konlpy.tag", "bs4", "pytest", "spacy", "spacy.lang.en", "numpy",]
+module = [
+    "transformers",
+    "sentence_transformers",
+    "nltk.tokenize",
+    "konlpy.tag",
+    "bs4",
+    "pytest",
+    "spacy",
+    "spacy.lang.en",
+    "numpy",
+    "nltk",
+    "spacy.cli",
+    "torch",
+]
 ignore_missing_imports = "True"
 
 [tool.poetry.urls]
@@ -26,15 +39,18 @@ python = ">=3.9,<4.0"
 langchain-core = "^0.3.25"
 
 [tool.ruff.lint]
-select = [ "E", "F", "I", "T201", "D",]
-ignore = [ "D100",]
+select = ["E", "F", "I", "T201", "D"]
+ignore = ["D100"]
 
 [tool.coverage.run]
-omit = [ "tests/*",]
+omit = ["tests/*"]
 
 [tool.pytest.ini_options]
 addopts = "--strict-markers --strict-config --durations=5"
-markers = [ "requires: mark tests as requiring a specific library", "compile: mark placeholder test used to compile integration tests without running them",]
+markers = [
+    "requires: mark tests as requiring a specific library",
+    "compile: mark placeholder test used to compile integration tests without running them",
+]
 asyncio_mode = "auto"
 
 [tool.poetry.group.lint]
@@ -53,19 +69,17 @@ optional = true
 convention = "google"
 
 [tool.ruff.lint.per-file-ignores]
-"tests/**" = [ "D",]
+"tests/**" = ["D"]
 
 [tool.poetry.group.lint.dependencies]
 ruff = "^0.5"
 
-
 [tool.poetry.group.typing.dependencies]
 mypy = "^1.10"
 lxml-stubs = "^0.5.1"
 types-requests = "^2.31.0.20240218"
 tiktoken = "^0.8.0"
 
-
 [tool.poetry.group.dev.dependencies]
 jupyter = "^1.0.0"
 
@@ -78,20 +92,23 @@ pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
 pytest-socket = "^0.7.0"
 
+[tool.poetry.group.test_integration]
+optional = true
+
 [tool.poetry.group.test_integration.dependencies]
+spacy = { version = "*", python = "<3.13" }
+nltk = "^3.9.1"
+transformers = "^4.47.0"
+sentence-transformers = { version = ">=2.6.0", python = "<3.13" }
 
-
 [tool.poetry.group.lint.dependencies.langchain-core]
 path = "../core"
 develop = true
 
-
 [tool.poetry.group.dev.dependencies.langchain-core]
 path = "../core"
 develop = true
 
-
 [tool.poetry.group.test.dependencies.langchain-core]
 path = "../core"
 develop = true
```
```diff
@@ -1,18 +1,36 @@
 """Test text splitting functionality using NLTK and Spacy based sentence splitters."""
 
+from typing import Any
+
+import nltk
 import pytest
+from langchain_core.documents import Document
 
 from langchain_text_splitters.nltk import NLTKTextSplitter
 from langchain_text_splitters.spacy import SpacyTextSplitter
 
 
+def setup_module() -> None:
+    nltk.download("punkt_tab")
+
+
+@pytest.fixture()
+def spacy() -> Any:
+    try:
+        import spacy
+    except ImportError:
+        pytest.skip("Spacy not installed.")
+    spacy.cli.download("en_core_web_sm")  # type: ignore
+    return spacy
+
+
 def test_nltk_text_splitting_args() -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):
         NLTKTextSplitter(chunk_size=2, chunk_overlap=4)
 
 
-def test_spacy_text_splitting_args() -> None:
+def test_spacy_text_splitting_args(spacy: Any) -> None:
     """Test invalid arguments."""
     with pytest.raises(ValueError):
         SpacyTextSplitter(chunk_size=2, chunk_overlap=4)
@@ -29,7 +47,7 @@ def test_nltk_text_splitter() -> None:
 
 
 @pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
-def test_spacy_text_splitter(pipeline: str) -> None:
+def test_spacy_text_splitter(pipeline: str, spacy: Any) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
@@ -40,7 +58,7 @@ def test_spacy_text_splitter(pipeline: str) -> None:
 
 
 @pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
-def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
+def test_spacy_text_splitter_strip_whitespace(pipeline: str, spacy: Any) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
@@ -50,3 +68,35 @@ def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
     output = splitter.split_text(text)
     expected_output = [f"This is sentence one. {separator}And this is sentence two."]
     assert output == expected_output
+
+
+def test_nltk_text_splitter_args() -> None:
+    """Test invalid arguments for NLTKTextSplitter."""
+    with pytest.raises(ValueError):
+        NLTKTextSplitter(
+            chunk_size=80,
+            chunk_overlap=0,
+            separator="\n\n",
+            use_span_tokenize=True,
+        )
+
+
+def test_nltk_text_splitter_with_add_start_index() -> None:
+    splitter = NLTKTextSplitter(
+        chunk_size=80,
+        chunk_overlap=0,
+        separator="",
+        use_span_tokenize=True,
+        add_start_index=True,
+    )
+    txt = (
+        "Innovation drives our success. "
+        "Collaboration fosters creative solutions. "
+        "Efficiency enhances data management."
+    )
+    docs = [Document(txt)]
+    chunks = splitter.split_documents(docs)
+    assert len(chunks) == 2
+    for chunk in chunks:
+        s_i = chunk.metadata["start_index"]
+        assert chunk.page_content == txt[s_i : s_i + len(chunk.page_content)]
```
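A rough sanity check on why this test expects exactly two chunks (assuming the default `len` length function and glossing over whitespace stripping): the first two sentences plus the joining space come to 72 characters, which fits in an 80-character chunk, while appending the third sentence would not.

```python
s1 = "Innovation drives our success."
s2 = "Collaboration fosters creative solutions."
s3 = "Efficiency enhances data management."

# First two sentences plus the joining space fit within chunk_size=80 ...
print(len(s1) + 1 + len(s2))                # 72
# ... but adding the third sentence overflows it, so a second chunk starts.
print(len(s1) + 1 + len(s2) + 1 + len(s3))  # 109
```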
```diff
@@ -1,5 +1,7 @@
 """Test text splitters that require an integration."""
 
+from typing import Any
+
 import pytest
 
 from langchain_text_splitters import (
@@ -11,6 +13,15 @@ from langchain_text_splitters.sentence_transformers import (
 )
 
 
+@pytest.fixture()
+def sentence_transformers() -> Any:
+    try:
+        import sentence_transformers
+    except ImportError:
+        pytest.skip("SentenceTransformers not installed.")
+    return sentence_transformers
+
+
 def test_huggingface_type_check() -> None:
     """Test that type checks are done properly on input."""
     with pytest.raises(ValueError):
@@ -52,7 +63,7 @@ def test_token_text_splitter_from_tiktoken() -> None:
     assert expected_tokenizer == actual_tokenizer
 
 
-def test_sentence_transformers_count_tokens() -> None:
+def test_sentence_transformers_count_tokens(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(
         model_name="sentence-transformers/paraphrase-albert-small-v2"
     )
@@ -67,7 +78,7 @@ def test_sentence_transformers_count_tokens() -> None:
     assert expected_token_count == token_count
 
 
-def test_sentence_transformers_split_text() -> None:
+def test_sentence_transformers_split_text(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(
         model_name="sentence-transformers/paraphrase-albert-small-v2"
     )
@@ -77,7 +88,7 @@ def test_sentence_transformers_split_text() -> None:
     assert expected_text_chunks == text_chunks
 
 
-def test_sentence_transformers_multiple_tokens() -> None:
+def test_sentence_transformers_multiple_tokens(sentence_transformers: Any) -> None:
     splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
     text = "Lorem "
```