mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 06:33:20 +00:00
langchain[patch]: inconsistent results with `RecursiveCharacterTextSplitter`'s `add_start_index=True` (#16583)
This PR fixes issue #16579
This commit is contained in:
parent
42db96477f
commit
08d3fd7f2e
@ -141,12 +141,15 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
_metadatas = metadatas or [{}] * len(texts)
|
_metadatas = metadatas or [{}] * len(texts)
|
||||||
documents = []
|
documents = []
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
index = -1
|
index = 0
|
||||||
|
previous_chunk_len = 0
|
||||||
for chunk in self.split_text(text):
|
for chunk in self.split_text(text):
|
||||||
metadata = copy.deepcopy(_metadatas[i])
|
metadata = copy.deepcopy(_metadatas[i])
|
||||||
if self._add_start_index:
|
if self._add_start_index:
|
||||||
index = text.find(chunk, index + 1)
|
offset = index + previous_chunk_len - self._chunk_overlap
|
||||||
|
index = text.find(chunk, max(0, offset))
|
||||||
metadata["start_index"] = index
|
metadata["start_index"] = index
|
||||||
|
previous_chunk_len = len(chunk)
|
||||||
new_doc = Document(page_content=chunk, metadata=metadata)
|
new_doc = Document(page_content=chunk, metadata=metadata)
|
||||||
documents.append(new_doc)
|
documents.append(new_doc)
|
||||||
return documents
|
return documents
|
||||||
|
@ -13,6 +13,7 @@ from langchain.text_splitter import (
|
|||||||
MarkdownHeaderTextSplitter,
|
MarkdownHeaderTextSplitter,
|
||||||
PythonCodeTextSplitter,
|
PythonCodeTextSplitter,
|
||||||
RecursiveCharacterTextSplitter,
|
RecursiveCharacterTextSplitter,
|
||||||
|
TextSplitter,
|
||||||
Tokenizer,
|
Tokenizer,
|
||||||
split_text_on_tokens,
|
split_text_on_tokens,
|
||||||
)
|
)
|
||||||
@ -169,19 +170,47 @@ def test_create_documents_with_metadata() -> None:
|
|||||||
assert docs == expected_docs
|
assert docs == expected_docs
|
||||||
|
|
||||||
|
|
||||||
def test_create_documents_with_start_index() -> None:
|
@pytest.mark.parametrize(
|
||||||
|
"splitter, text, expected_docs",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
CharacterTextSplitter(
|
||||||
|
separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
|
||||||
|
),
|
||||||
|
"foo bar baz 123",
|
||||||
|
[
|
||||||
|
Document(page_content="foo bar", metadata={"start_index": 0}),
|
||||||
|
Document(page_content="bar baz", metadata={"start_index": 4}),
|
||||||
|
Document(page_content="baz 123", metadata={"start_index": 8}),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=6,
|
||||||
|
chunk_overlap=0,
|
||||||
|
separators=["\n\n", "\n", " ", ""],
|
||||||
|
add_start_index=True,
|
||||||
|
),
|
||||||
|
"w1 w1 w1 w1 w1 w1 w1 w1 w1",
|
||||||
|
[
|
||||||
|
Document(page_content="w1 w1", metadata={"start_index": 0}),
|
||||||
|
Document(page_content="w1 w1", metadata={"start_index": 6}),
|
||||||
|
Document(page_content="w1 w1", metadata={"start_index": 12}),
|
||||||
|
Document(page_content="w1 w1", metadata={"start_index": 18}),
|
||||||
|
Document(page_content="w1", metadata={"start_index": 24}),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_create_documents_with_start_index(
|
||||||
|
splitter: TextSplitter, text: str, expected_docs: List[Document]
|
||||||
|
) -> None:
|
||||||
"""Test create documents method."""
|
"""Test create documents method."""
|
||||||
texts = ["foo bar baz 123"]
|
docs = splitter.create_documents([text])
|
||||||
splitter = CharacterTextSplitter(
|
|
||||||
separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
|
|
||||||
)
|
|
||||||
docs = splitter.create_documents(texts)
|
|
||||||
expected_docs = [
|
|
||||||
Document(page_content="foo bar", metadata={"start_index": 0}),
|
|
||||||
Document(page_content="bar baz", metadata={"start_index": 4}),
|
|
||||||
Document(page_content="baz 123", metadata={"start_index": 8}),
|
|
||||||
]
|
|
||||||
assert docs == expected_docs
|
assert docs == expected_docs
|
||||||
|
for doc in docs:
|
||||||
|
s_i = doc.metadata["start_index"]
|
||||||
|
assert text[s_i : s_i + len(doc.page_content)] == doc.page_content
|
||||||
|
|
||||||
|
|
||||||
def test_metadata_not_shallow() -> None:
|
def test_metadata_not_shallow() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user