Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-06 15:48:39 +00:00)
smart text splitter (#530)
smart text splitter that iteratively tries different separators until it works!
Parent commit: 8dfad874a2
This commit: 1192cc0767
Changes to the text splitter example notebook:

@@ -90,6 +90,61 @@
     "print(texts[0])"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "1be00b73",
+   "metadata": {},
+   "source": [
+    "## Recursive Character Text Splitting\n",
+    "Sometimes, it's not enough to split on just one character. This text splitter uses a whole list of characters and recursively splits them down until they are under the limit."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "1ac6376d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.text_splitter import RecursiveCharacterTextSplitter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "6787b13b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_splitter = RecursiveCharacterTextSplitter(\n",
+    "    # Set a really small chunk size, just to show.\n",
+    "    chunk_size = 100,\n",
+    "    chunk_overlap = 20,\n",
+    "    length_function = len,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4f0e7d9b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet.\n",
+      "and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n"
+     ]
+    }
+   ],
+   "source": [
+    "texts = text_splitter.split_text(state_of_the_union)\n",
+    "print(texts[0])\n",
+    "print(texts[1])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "87a71115",
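For reference, here is a self-contained variant of the new notebook cells. This is a sketch, not part of the commit: `sample_text` stands in for the `state_of_the_union` string the notebook loads earlier, and the import path is the one used at this commit (later langchain releases moved the splitters into the separate `langchain_text_splitters` package). Note how the second output chunk above begins with "and the Cabinet." — that repetition is the 20-character chunk_overlap at work.

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Stand-in for the state_of_the_union text loaded earlier in the notebook.
sample_text = (
    "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. "
    "Members of Congress and the Cabinet. Justices of the Supreme Court. "
    "My fellow Americans."
)

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,   # deliberately small, to force several chunks
    chunk_overlap=20,
    length_function=len,
)
for chunk in text_splitter.split_text(sample_text):
    print(repr(chunk))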
In langchain/text_splitter.py:

@@ -15,7 +15,6 @@ class TextSplitter(ABC):

     def __init__(
         self,
-        separator: str = "\n\n",
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
         length_function: Callable[[str], int] = len,
@@ -26,7 +25,6 @@ class TextSplitter(ABC):
             f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
             f"({chunk_size}), should be smaller."
         )
-        self._separator = separator
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
@@ -46,7 +44,7 @@ class TextSplitter(ABC):
             documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
         return documents

-    def _merge_splits(self, splits: Iterable[str]) -> List[str]:
+    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
         docs = []
@@ -61,13 +59,18 @@ class TextSplitter(ABC):
                     f"which is longer than the specified {self._chunk_size}"
                 )
                 if len(current_doc) > 0:
-                    docs.append(self._separator.join(current_doc))
-                    while total > self._chunk_overlap:
+                    docs.append(separator.join(current_doc))
+                    # Keep on popping if:
+                    # - we have a larger chunk than in the chunk overlap
+                    # - or if we still have any chunks and the length is long
+                    while total > self._chunk_overlap or (
+                        total + _len > self._chunk_size and total > 0
+                    ):
                         total -= self._length_function(current_doc[0])
                         current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len
-        docs.append(self._separator.join(current_doc))
+        docs.append(separator.join(current_doc))
         return docs

     @classmethod
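To see why the added `or` clause matters, here is a minimal standalone sketch of the merge loop above, with plain `len` in place of `self._length_function` and a hypothetical top-level name `merge_splits`. Without the second condition, carried-over overlap text plus a long incoming split could produce an oversized chunk; with it, the buffer keeps draining until the next split fits.

from typing import Iterable, List


def merge_splits(
    splits: Iterable[str], separator: str, chunk_size: int, chunk_overlap: int
) -> List[str]:
    """Greedily pack splits into chunks, carrying chunk_overlap across them."""
    docs: List[str] = []
    current_doc: List[str] = []  # splits buffered for the chunk in progress
    total = 0                    # combined length of current_doc
    for d in splits:
        _len = len(d)
        if total + _len >= chunk_size:
            if current_doc:
                docs.append(separator.join(current_doc))
                # Pop from the front while the carried-over text exceeds the
                # overlap budget, OR while keeping it would push the next
                # chunk over the size limit (the condition this commit adds).
                while total > chunk_overlap or (
                    total + _len > chunk_size and total > 0
                ):
                    total -= len(current_doc[0])
                    current_doc = current_doc[1:]
        current_doc.append(d)
        total += _len
    docs.append(separator.join(current_doc))
    return docs


print(merge_splits(["a", "a", "foo", "bar", "baz"], " ", 3, 1))
# -> ['a a', 'foo', 'bar', 'baz'], matching the new unit test below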
@@ -116,21 +119,74 @@ class TextSplitter(ABC):
 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""

+    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        self._separator = separator
+
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
-        splits = text.split(self._separator)
-        return self._merge_splits(splits)
+        if self._separator:
+            splits = text.split(self._separator)
+        else:
+            splits = list(text)
+        return self._merge_splits(splits, self._separator)
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+    """Implementation of splitting text that looks at characters.
+
+    Recursively tries to split by different characters to find one
+    that works.
+    """
+
+    def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        self._separators = separators or ["\n\n", "\n", " ", ""]
+
+    def split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Get appropriate separator to use
+        separator = self._separators[-1]
+        for _s in self._separators:
+            if _s == "":
+                separator = _s
+                break
+            if _s in text:
+                separator = _s
+                break
+        # Now that we have the separator, split the text
+        if separator:
+            splits = text.split(separator)
+        else:
+            splits = list(text)
+        # Now go merging things, recursively splitting longer texts.
+        _good_splits = []
+        for s in splits:
+            if len(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                other_info = self.split_text(s)
+                final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, separator)
+            final_chunks.extend(merged_text)
+        return final_chunks
+
+
 class NLTKTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at sentences using NLTK."""

-    def __init__(
-        self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
-    ):
+    def __init__(self, separator: str = "\n\n", **kwargs: Any):
         """Initialize the NLTK splitter."""
-        super(NLTKTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
+        super().__init__(**kwargs)
         try:
             from nltk.tokenize import sent_tokenize

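The separator choice in the new `split_text` is a first-match scan over `self._separators`, with the empty string as a character-level fallback. A small standalone sketch of just that scan (hypothetical helper name, same logic as the loop above):

from typing import List


def pick_separator(text: str, separators: List[str]) -> str:
    """First separator in the list that appears in the text wins;
    '' (split into single characters) is the final fallback."""
    separator = separators[-1]
    for s in separators:
        if s == "" or s in text:
            separator = s
            break
    return separator


seps = ["\n\n", "\n", " ", ""]
print(repr(pick_separator("one\n\ntwo", seps)))    # '\n\n'
print(repr(pick_separator("one two three", seps))) # ' '
print(repr(pick_separator("onetwothree", seps)))   # ''

Any split still longer than chunk_size is then fed back through `split_text`, which picks the next finer separator on the recursive call — that is the "iteratively tries different separators until it works" behavior from the commit message.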
@@ -139,26 +195,23 @@ class NLTKTextSplitter(TextSplitter):
             raise ImportError(
                 "NLTK is not installed, please install it with `pip install nltk`."
             )
+        self._separator = separator

     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
         splits = self._tokenizer(text)
-        return self._merge_splits(splits)
+        return self._merge_splits(splits, self._separator)


 class SpacyTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at sentences using Spacy."""

     def __init__(
-        self,
-        separator: str = "\n\n",
-        pipeline: str = "en_core_web_sm",
-        chunk_size: int = 4000,
-        chunk_overlap: int = 200,
+        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
     ):
         """Initialize the spacy text splitter."""
-        super(SpacyTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
+        super().__init__(**kwargs)
         try:
             import spacy
         except ImportError:
@@ -166,8 +219,9 @@ class SpacyTextSplitter(TextSplitter):
                 "Spacy is not installed, please install it with `pip install spacy`."
             )
         self._tokenizer = spacy.load(pipeline)
+        self._separator = separator

     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         splits = (str(s) for s in self._tokenizer(text).sents)
-        return self._merge_splits(splits)
+        return self._merge_splits(splits, self._separator)
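Net effect of the refactor: `separator` is now a per-subclass parameter, and `chunk_size`/`chunk_overlap` flow through `**kwargs` to the base class. A usage sketch (assuming a langchain version matching this commit); the printed result is exactly what the new `test_character_text_splitter_short_words_first` test below asserts:

from langchain.text_splitter import CharacterTextSplitter

# separator is consumed by CharacterTextSplitter; the remaining keyword
# arguments are forwarded to the TextSplitter base class.
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
print(splitter.split_text("a a foo bar baz"))  # ['a a', 'foo', 'bar', 'baz']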
In the text splitter unit tests:

@@ -2,7 +2,10 @@
 import pytest

 from langchain.docstore.document import Document
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import (
+    CharacterTextSplitter,
+    RecursiveCharacterTextSplitter,
+)


 def test_character_text_splitter() -> None:
@@ -23,6 +26,15 @@ def test_character_text_splitter_long() -> None:
     assert output == expected_output


+def test_character_text_splitter_short_words_first() -> None:
+    """Test splitting by character count when shorter words are first."""
+    text = "a a foo bar baz"
+    splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
+    output = splitter.split_text(text)
+    expected_output = ["a a", "foo", "bar", "baz"]
+    assert output == expected_output
+
+
 def test_character_text_splitter_longer_words() -> None:
     """Test splitting by characters when splits not found easily."""
     text = "foo bar baz 123"
@@ -62,3 +74,33 @@ def test_create_documents_with_metadata() -> None:
         Document(page_content="baz", metadata={"source": "2"}),
     ]
     assert docs == expected_docs
+
+
+def test_iterative_text_splitter() -> None:
+    """Test iterative text splitter."""
+    text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
+This is a weird text to write, but gotta test the splittingggg some how.
+
+Bye!\n\n-H."""
+    splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
+    output = splitter.split_text(text)
+    expected_output = [
+        "Hi.",
+        "I'm",
+        "Harrison.",
+        "How? Are?",
+        "You?",
+        "Okay then f",
+        "f f f f.",
+        "This is a",
+        "a weird",
+        "text to",
+        "write, but",
+        "gotta test",
+        "the",
+        "splitting",
+        "gggg",
+        "some how.",
+        "Bye!\n\n-H.",
+    ]
+    assert output == expected_output