mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 12:48:12 +00:00
text-splitters: Inconsistent results with NLTKTextSplitter
's add_start_index=True
(#27782)
This PR closes #27781 # Problem The current implementation of `NLTKTextSplitter` is using `sent_tokenize`. However, this `sent_tokenize` doesn't handle chars between 2 tokenized sentences... hence, this behavior throws errors when we are using `add_start_index=True`, as described in issue #27781. In particular: ```python from nltk.tokenize import sent_tokenize output1 = sent_tokenize("Innovation drives our success. Collaboration fosters creative solutions. Efficiency enhances data management.", language="english") print(output1) output2 = sent_tokenize("Innovation drives our success. Collaboration fosters creative solutions. Efficiency enhances data management.", language="english") print(output2) >>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.'] >>> ['Innovation drives our success.', 'Collaboration fosters creative solutions.', 'Efficiency enhances data management.'] ``` # Solution With this new `use_span_tokenize` parameter, we can use NLTK to create sentences (with `span_tokenize`), but also add extra chars to be sure that we still can map the chunks to the original text. --------- Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Erick Friis <erickfriis@gmail.com>
This commit is contained in:
parent
d262d41cc0
commit
b2102b8cc4
@ -9,6 +9,9 @@ TEST_FILE ?= tests/unit_tests/
|
||||
test tests:
|
||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
integration_test integration_tests:
|
||||
poetry run pytest tests/integration_tests/
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
|
||||
|
||||
|
@ -9,11 +9,26 @@ class NLTKTextSplitter(TextSplitter):
|
||||
"""Splitting text using NLTK package."""
|
||||
|
||||
def __init__(
|
||||
self, separator: str = "\n\n", language: str = "english", **kwargs: Any
|
||||
self,
|
||||
separator: str = "\n\n",
|
||||
language: str = "english",
|
||||
*,
|
||||
use_span_tokenize: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the NLTK splitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._separator = separator
|
||||
self._language = language
|
||||
self._use_span_tokenize = use_span_tokenize
|
||||
if self._use_span_tokenize and self._separator != "":
|
||||
raise ValueError("When use_span_tokenize is True, separator should be ''")
|
||||
try:
|
||||
if self._use_span_tokenize:
|
||||
from nltk.tokenize import _get_punkt_tokenizer
|
||||
|
||||
self._tokenizer = _get_punkt_tokenizer(self._language)
|
||||
else:
|
||||
from nltk.tokenize import sent_tokenize
|
||||
|
||||
self._tokenizer = sent_tokenize
|
||||
@ -21,11 +36,20 @@ class NLTKTextSplitter(TextSplitter):
|
||||
raise ImportError(
|
||||
"NLTK is not installed, please install it with `pip install nltk`."
|
||||
)
|
||||
self._separator = separator
|
||||
self._language = language
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
# First we naively split the large input into a bunch of smaller ones.
|
||||
if self._use_span_tokenize:
|
||||
spans = list(self._tokenizer.span_tokenize(text))
|
||||
splits = []
|
||||
for i, (start, end) in enumerate(spans):
|
||||
if i > 0:
|
||||
prev_end = spans[i - 1][1]
|
||||
sentence = text[prev_end:start] + text[start:end]
|
||||
else:
|
||||
sentence = text[start:end]
|
||||
splits.append(sentence)
|
||||
else:
|
||||
splits = self._tokenizer(text, language=self._language)
|
||||
return self._merge_splits(splits, self._separator)
|
||||
|
@ -22,7 +22,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import sentence_transformer python package. "
|
||||
"Could not import sentence_transformers python package. "
|
||||
"This is needed in order to for SentenceTransformersTokenTextSplitter. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
)
|
||||
|
1720
libs/text-splitters/poetry.lock
generated
1720
libs/text-splitters/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
[build-system]
|
||||
requires = [ "poetry-core>=1.0.0",]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
@ -14,7 +14,20 @@ repository = "https://github.com/langchain-ai/langchain"
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
[[tool.mypy.overrides]]
|
||||
module = [ "transformers", "sentence_transformers", "nltk.tokenize", "konlpy.tag", "bs4", "pytest", "spacy", "spacy.lang.en", "numpy",]
|
||||
module = [
|
||||
"transformers",
|
||||
"sentence_transformers",
|
||||
"nltk.tokenize",
|
||||
"konlpy.tag",
|
||||
"bs4",
|
||||
"pytest",
|
||||
"spacy",
|
||||
"spacy.lang.en",
|
||||
"numpy",
|
||||
"nltk",
|
||||
"spacy.cli",
|
||||
"torch",
|
||||
]
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.poetry.urls]
|
||||
@ -26,15 +39,18 @@ python = ">=3.9,<4.0"
|
||||
langchain-core = "^0.3.25"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [ "E", "F", "I", "T201", "D",]
|
||||
ignore = [ "D100",]
|
||||
select = ["E", "F", "I", "T201", "D"]
|
||||
ignore = ["D100"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [ "tests/*",]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
markers = [ "requires: mark tests as requiring a specific library", "compile: mark placeholder test used to compile integration tests without running them",]
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
@ -53,19 +69,17 @@ optional = true
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**" = [ "D",]
|
||||
"tests/**" = ["D"]
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.5"
|
||||
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1.10"
|
||||
lxml-stubs = "^0.5.1"
|
||||
types-requests = "^2.31.0.20240218"
|
||||
tiktoken = "^0.8.0"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyter = "^1.0.0"
|
||||
|
||||
@ -78,20 +92,23 @@ pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
pytest-socket = "^0.7.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
spacy = { version = "*", python = "<3.13" }
|
||||
nltk = "^3.9.1"
|
||||
transformers = "^4.47.0"
|
||||
sentence-transformers = { version = ">=2.6.0", python = "<3.13" }
|
||||
|
||||
[tool.poetry.group.lint.dependencies.langchain-core]
|
||||
path = "../core"
|
||||
develop = true
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies.langchain-core]
|
||||
path = "../core"
|
||||
develop = true
|
||||
|
||||
|
||||
[tool.poetry.group.test.dependencies.langchain-core]
|
||||
path = "../core"
|
||||
develop = true
|
||||
|
@ -1,18 +1,36 @@
|
||||
"""Test text splitting functionality using NLTK and Spacy based sentence splitters."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import nltk
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_text_splitters.nltk import NLTKTextSplitter
|
||||
from langchain_text_splitters.spacy import SpacyTextSplitter
|
||||
|
||||
|
||||
def setup_module() -> None:
|
||||
nltk.download("punkt_tab")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def spacy() -> Any:
|
||||
try:
|
||||
import spacy
|
||||
except ImportError:
|
||||
pytest.skip("Spacy not installed.")
|
||||
spacy.cli.download("en_core_web_sm") # type: ignore
|
||||
return spacy
|
||||
|
||||
|
||||
def test_nltk_text_splitting_args() -> None:
|
||||
"""Test invalid arguments."""
|
||||
with pytest.raises(ValueError):
|
||||
NLTKTextSplitter(chunk_size=2, chunk_overlap=4)
|
||||
|
||||
|
||||
def test_spacy_text_splitting_args() -> None:
|
||||
def test_spacy_text_splitting_args(spacy: Any) -> None:
|
||||
"""Test invalid arguments."""
|
||||
with pytest.raises(ValueError):
|
||||
SpacyTextSplitter(chunk_size=2, chunk_overlap=4)
|
||||
@ -29,7 +47,7 @@ def test_nltk_text_splitter() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
|
||||
def test_spacy_text_splitter(pipeline: str) -> None:
|
||||
def test_spacy_text_splitter(pipeline: str, spacy: Any) -> None:
|
||||
"""Test splitting by sentence using Spacy."""
|
||||
text = "This is sentence one. And this is sentence two."
|
||||
separator = "|||"
|
||||
@ -40,7 +58,7 @@ def test_spacy_text_splitter(pipeline: str) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
|
||||
def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
|
||||
def test_spacy_text_splitter_strip_whitespace(pipeline: str, spacy: Any) -> None:
|
||||
"""Test splitting by sentence using Spacy."""
|
||||
text = "This is sentence one. And this is sentence two."
|
||||
separator = "|||"
|
||||
@ -50,3 +68,35 @@ def test_spacy_text_splitter_strip_whitespace(pipeline: str) -> None:
|
||||
output = splitter.split_text(text)
|
||||
expected_output = [f"This is sentence one. {separator}And this is sentence two."]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_nltk_text_splitter_args() -> None:
|
||||
"""Test invalid arguments for NLTKTextSplitter."""
|
||||
with pytest.raises(ValueError):
|
||||
NLTKTextSplitter(
|
||||
chunk_size=80,
|
||||
chunk_overlap=0,
|
||||
separator="\n\n",
|
||||
use_span_tokenize=True,
|
||||
)
|
||||
|
||||
|
||||
def test_nltk_text_splitter_with_add_start_index() -> None:
|
||||
splitter = NLTKTextSplitter(
|
||||
chunk_size=80,
|
||||
chunk_overlap=0,
|
||||
separator="",
|
||||
use_span_tokenize=True,
|
||||
add_start_index=True,
|
||||
)
|
||||
txt = (
|
||||
"Innovation drives our success. "
|
||||
"Collaboration fosters creative solutions. "
|
||||
"Efficiency enhances data management."
|
||||
)
|
||||
docs = [Document(txt)]
|
||||
chunks = splitter.split_documents(docs)
|
||||
assert len(chunks) == 2
|
||||
for chunk in chunks:
|
||||
s_i = chunk.metadata["start_index"]
|
||||
assert chunk.page_content == txt[s_i : s_i + len(chunk.page_content)]
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Test text splitters that require an integration."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_text_splitters import (
|
||||
@ -11,6 +13,15 @@ from langchain_text_splitters.sentence_transformers import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sentence_transformers() -> Any:
|
||||
try:
|
||||
import sentence_transformers
|
||||
except ImportError:
|
||||
pytest.skip("SentenceTransformers not installed.")
|
||||
return sentence_transformers
|
||||
|
||||
|
||||
def test_huggingface_type_check() -> None:
|
||||
"""Test that type checks are done properly on input."""
|
||||
with pytest.raises(ValueError):
|
||||
@ -52,7 +63,7 @@ def test_token_text_splitter_from_tiktoken() -> None:
|
||||
assert expected_tokenizer == actual_tokenizer
|
||||
|
||||
|
||||
def test_sentence_transformers_count_tokens() -> None:
|
||||
def test_sentence_transformers_count_tokens(sentence_transformers: Any) -> None:
|
||||
splitter = SentenceTransformersTokenTextSplitter(
|
||||
model_name="sentence-transformers/paraphrase-albert-small-v2"
|
||||
)
|
||||
@ -67,7 +78,7 @@ def test_sentence_transformers_count_tokens() -> None:
|
||||
assert expected_token_count == token_count
|
||||
|
||||
|
||||
def test_sentence_transformers_split_text() -> None:
|
||||
def test_sentence_transformers_split_text(sentence_transformers: Any) -> None:
|
||||
splitter = SentenceTransformersTokenTextSplitter(
|
||||
model_name="sentence-transformers/paraphrase-albert-small-v2"
|
||||
)
|
||||
@ -77,7 +88,7 @@ def test_sentence_transformers_split_text() -> None:
|
||||
assert expected_text_chunks == text_chunks
|
||||
|
||||
|
||||
def test_sentence_transformers_multiple_tokens() -> None:
|
||||
def test_sentence_transformers_multiple_tokens(sentence_transformers: Any) -> None:
|
||||
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
|
||||
text = "Lorem "
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user