mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-02 13:55:42 +00:00
smart text splitter (#530)
smart text splitter that iteratively tries different separators until it works!
This commit is contained in:
parent
8dfad874a2
commit
1192cc0767
@ -90,6 +90,61 @@
|
||||
"print(texts[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1be00b73",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Recursive Character Text Splitting\n",
|
||||
"Sometimes, it's not enough to split on just one character. This text splitter uses a whole list of characters and recursive splits them down until they are under the limit."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "1ac6376d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6787b13b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text_splitter = RecursiveCharacterTextSplitter(\n",
|
||||
" # Set a really small chunk size, just to show.\n",
|
||||
" chunk_size = 100,\n",
|
||||
" chunk_overlap = 20,\n",
|
||||
" length_function = len,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "4f0e7d9b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet.\n",
|
||||
"and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"texts = text_splitter.split_text(state_of_the_union)\n",
|
||||
"print(texts[0])\n",
|
||||
"print(texts[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "87a71115",
|
||||
|
@ -15,7 +15,6 @@ class TextSplitter(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
separator: str = "\n\n",
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
@ -26,7 +25,6 @@ class TextSplitter(ABC):
|
||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
||||
f"({chunk_size}), should be smaller."
|
||||
)
|
||||
self._separator = separator
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._length_function = length_function
|
||||
@ -46,7 +44,7 @@ class TextSplitter(ABC):
|
||||
documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
|
||||
return documents
|
||||
|
||||
def _merge_splits(self, splits: Iterable[str]) -> List[str]:
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
|
||||
# We now want to combine these smaller pieces into medium size
|
||||
# chunks to send to the LLM.
|
||||
docs = []
|
||||
@ -61,13 +59,18 @@ class TextSplitter(ABC):
|
||||
f"which is longer than the specified {self._chunk_size}"
|
||||
)
|
||||
if len(current_doc) > 0:
|
||||
docs.append(self._separator.join(current_doc))
|
||||
while total > self._chunk_overlap:
|
||||
docs.append(separator.join(current_doc))
|
||||
# Keep on popping if:
|
||||
# - we have a larger chunk than in the chunk overlap
|
||||
# - or if we still have any chunks and the length is long
|
||||
while total > self._chunk_overlap or (
|
||||
total + _len > self._chunk_size and total > 0
|
||||
):
|
||||
total -= self._length_function(current_doc[0])
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len
|
||||
docs.append(self._separator.join(current_doc))
|
||||
docs.append(separator.join(current_doc))
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
@ -116,21 +119,74 @@ class TextSplitter(ABC):
|
||||
class CharacterTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at characters."""
|
||||
|
||||
def __init__(self, separator: str = "\n\n", **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._separator = separator
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
# First we naively split the large input into a bunch of smaller ones.
|
||||
splits = text.split(self._separator)
|
||||
return self._merge_splits(splits)
|
||||
if self._separator:
|
||||
splits = text.split(self._separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
return self._merge_splits(splits, self._separator)
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at characters.
|
||||
|
||||
Recursively tries to split by different characters to find one
|
||||
that works.
|
||||
"""
|
||||
|
||||
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = self._separators[-1]
|
||||
for _s in self._separators:
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if _s in text:
|
||||
separator = _s
|
||||
break
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
for s in splits:
|
||||
if len(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
other_info = self.split_text(s)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
final_chunks.extend(merged_text)
|
||||
return final_chunks
|
||||
|
||||
|
||||
class NLTKTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at sentences using NLTK."""
|
||||
|
||||
def __init__(
|
||||
self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
|
||||
):
|
||||
def __init__(self, separator: str = "\n\n", **kwargs: Any):
|
||||
"""Initialize the NLTK splitter."""
|
||||
super(NLTKTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
from nltk.tokenize import sent_tokenize
|
||||
|
||||
@ -139,26 +195,23 @@ class NLTKTextSplitter(TextSplitter):
|
||||
raise ImportError(
|
||||
"NLTK is not installed, please install it with `pip install nltk`."
|
||||
)
|
||||
self._separator = separator
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
# First we naively split the large input into a bunch of smaller ones.
|
||||
splits = self._tokenizer(text)
|
||||
return self._merge_splits(splits)
|
||||
return self._merge_splits(splits, self._separator)
|
||||
|
||||
|
||||
class SpacyTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at sentences using Spacy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
separator: str = "\n\n",
|
||||
pipeline: str = "en_core_web_sm",
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
|
||||
):
|
||||
"""Initialize the spacy text splitter."""
|
||||
super(SpacyTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
|
||||
super.__init__(**kwargs)
|
||||
try:
|
||||
import spacy
|
||||
except ImportError:
|
||||
@ -166,8 +219,9 @@ class SpacyTextSplitter(TextSplitter):
|
||||
"Spacy is not installed, please install it with `pip install spacy`."
|
||||
)
|
||||
self._tokenizer = spacy.load(pipeline)
|
||||
self._separator = separator
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
splits = (str(s) for s in self._tokenizer(text).sents)
|
||||
return self._merge_splits(splits)
|
||||
return self._merge_splits(splits, self._separator)
|
||||
|
@ -2,7 +2,10 @@
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.text_splitter import (
|
||||
CharacterTextSplitter,
|
||||
RecursiveCharacterTextSplitter,
|
||||
)
|
||||
|
||||
|
||||
def test_character_text_splitter() -> None:
|
||||
@ -23,6 +26,15 @@ def test_character_text_splitter_long() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_short_words_first() -> None:
|
||||
"""Test splitting by character count when shorter words are first."""
|
||||
text = "a a foo bar baz"
|
||||
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
|
||||
output = splitter.split_text(text)
|
||||
expected_output = ["a a", "foo", "bar", "baz"]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_longer_words() -> None:
|
||||
"""Test splitting by characters when splits not found easily."""
|
||||
text = "foo bar baz 123"
|
||||
@ -62,3 +74,33 @@ def test_create_documents_with_metadata() -> None:
|
||||
Document(page_content="baz", metadata={"source": "2"}),
|
||||
]
|
||||
assert docs == expected_docs
|
||||
|
||||
|
||||
def test_iterative_text_splitter() -> None:
|
||||
"""Test iterative text splitter."""
|
||||
text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
|
||||
This is a weird text to write, but gotta test the splittingggg some how.
|
||||
|
||||
Bye!\n\n-H."""
|
||||
splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
|
||||
output = splitter.split_text(text)
|
||||
expected_output = [
|
||||
"Hi.",
|
||||
"I'm",
|
||||
"Harrison.",
|
||||
"How? Are?",
|
||||
"You?",
|
||||
"Okay then f",
|
||||
"f f f f.",
|
||||
"This is a",
|
||||
"a weird",
|
||||
"text to",
|
||||
"write, but",
|
||||
"gotta test",
|
||||
"the",
|
||||
"splitting",
|
||||
"gggg",
|
||||
"some how.",
|
||||
"Bye!\n\n-H.",
|
||||
]
|
||||
assert output == expected_output
|
||||
|
Loading…
Reference in New Issue
Block a user