diff --git a/docs/modules/utils/combine_docs_examples/textsplitter.ipynb b/docs/modules/utils/combine_docs_examples/textsplitter.ipynb
index ad6dafdd205..86c46cffaee 100644
--- a/docs/modules/utils/combine_docs_examples/textsplitter.ipynb
+++ b/docs/modules/utils/combine_docs_examples/textsplitter.ipynb
@@ -90,6 +90,61 @@
     "print(texts[0])"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "1be00b73",
+   "metadata": {},
+   "source": [
+    "## Recursive Character Text Splitting\n",
+    "Sometimes, it's not enough to split on just one character. This text splitter uses a whole list of characters and recursively splits the text down until the chunks are under the limit."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "1ac6376d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.text_splitter import RecursiveCharacterTextSplitter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "6787b13b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_splitter = RecursiveCharacterTextSplitter(\n",
+    "    # Set a really small chunk size, just to show.\n",
+    "    chunk_size=100,\n",
+    "    chunk_overlap=20,\n",
+    "    length_function=len,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4f0e7d9b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet.\n",
+      "and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n"
+     ]
+    }
+   ],
+   "source": [
+    "texts = text_splitter.split_text(state_of_the_union)\n",
+    "print(texts[0])\n",
+    "print(texts[1])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "87a71115",
diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 69c995c25e6..2012b6ec614 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -15,7 +15,6 @@ class TextSplitter(ABC):
 
     def __init__(
         self,
-        separator: str = "\n\n",
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
         length_function: Callable[[str], int] = len,
@@ -26,7 +25,6 @@ class TextSplitter(ABC):
                 f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                 f"({chunk_size}), should be smaller."
             )
-        self._separator = separator
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
         self._length_function = length_function
@@ -46,7 +44,7 @@ class TextSplitter(ABC):
             documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
         return documents
 
-    def _merge_splits(self, splits: Iterable[str]) -> List[str]:
+    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
         docs = []
@@ -61,13 +59,18 @@ class TextSplitter(ABC):
                     f"which is longer than the specified {self._chunk_size}"
                 )
                 if len(current_doc) > 0:
-                    docs.append(self._separator.join(current_doc))
-                    while total > self._chunk_overlap:
+                    docs.append(separator.join(current_doc))
+                    # Keep popping splits off the front while the retained
+                    # total exceeds the chunk overlap, or while the incoming
+                    # split still would not fit under the chunk size.
+                    while total > self._chunk_overlap or (
+                        total + _len > self._chunk_size and total > 0
+                    ):
                         total -= self._length_function(current_doc[0])
                         current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len
-        docs.append(self._separator.join(current_doc))
+        docs.append(separator.join(current_doc))
         return docs
 
     @classmethod
@@ -116,21 +119,74 @@ class TextSplitter(ABC):
 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""
 
+    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        self._separator = separator
+
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
-        splits = text.split(self._separator)
-        return self._merge_splits(splits)
+        if self._separator:
+            splits = text.split(self._separator)
+        else:
+            splits = list(text)
+        return self._merge_splits(splits, self._separator)
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+    """Implementation of splitting text that looks at characters.
+
+    Recursively tries to split by different characters to find one
+    that works.
+    """
+
+    def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        self._separators = separators or ["\n\n", "\n", " ", ""]
+
+    def split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        final_chunks = []
+        # Use the first separator in the list that appears in the text.
+        separator = self._separators[-1]
+        for _s in self._separators:
+            if _s == "":
+                separator = _s
+                break
+            if _s in text:
+                separator = _s
+                break
+        # Now that we have the separator, split the text.
+        if separator:
+            splits = text.split(separator)
+        else:
+            splits = list(text)
+        # Merge the splits, recursively re-splitting any that are too long.
+        _good_splits = []
+        for s in splits:
+            if len(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                other_info = self.split_text(s)
+                final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, separator)
+            final_chunks.extend(merged_text)
+        return final_chunks
 
 
 class NLTKTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at sentences using NLTK."""
 
-    def __init__(
-        self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
-    ):
+    def __init__(self, separator: str = "\n\n", **kwargs: Any):
         """Initialize the NLTK splitter."""
-        super(NLTKTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
+        super().__init__(**kwargs)
         try:
             from nltk.tokenize import sent_tokenize
 
@@ -139,26 +195,23 @@ class NLTKTextSplitter(TextSplitter):
             raise ImportError(
                 "NLTK is not installed, please install it with `pip install nltk`."
             )
+        self._separator = separator
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
         splits = self._tokenizer(text)
-        return self._merge_splits(splits)
+        return self._merge_splits(splits, self._separator)
 
 
 class SpacyTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at sentences using Spacy."""
 
     def __init__(
-        self,
-        separator: str = "\n\n",
-        pipeline: str = "en_core_web_sm",
-        chunk_size: int = 4000,
-        chunk_overlap: int = 200,
+        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
     ):
         """Initialize the spacy text splitter."""
-        super(SpacyTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
+        super().__init__(**kwargs)
         try:
             import spacy
         except ImportError:
@@ -166,8 +219,9 @@
                 "Spacy is not installed, please install it with `pip install spacy`."
             )
         self._tokenizer = spacy.load(pipeline)
+        self._separator = separator
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         splits = (str(s) for s in self._tokenizer(text).sents)
-        return self._merge_splits(splits)
+        return self._merge_splits(splits, self._separator)
diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py
index f22ad69bf94..884ca594d31 100644
--- a/tests/unit_tests/test_text_splitter.py
+++ b/tests/unit_tests/test_text_splitter.py
@@ -2,7 +2,10 @@
 import pytest
 
 from langchain.docstore.document import Document
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import (
+    CharacterTextSplitter,
+    RecursiveCharacterTextSplitter,
+)
 
 
 def test_character_text_splitter() -> None:
@@ -23,6 +26,15 @@ def test_character_text_splitter_long() -> None:
     assert output == expected_output
 
 
+def test_character_text_splitter_short_words_first() -> None:
+    """Test splitting by character count when shorter words are first."""
+    text = "a a foo bar baz"
+    splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
+    output = splitter.split_text(text)
+    expected_output = ["a a", "foo", "bar", "baz"]
+    assert output == expected_output
+
+
 def test_character_text_splitter_longer_words() -> None:
     """Test splitting by characters when splits not found easily."""
     text = "foo bar baz 123"
@@ -62,3 +74,33 @@ def test_create_documents_with_metadata() -> None:
         Document(page_content="baz", metadata={"source": "2"}),
     ]
     assert docs == expected_docs
+
+
+def test_iterative_text_splitter() -> None:
+    """Test iterative text splitter."""
+    text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
+This is a weird text to write, but gotta test the splittingggg some how.
+
+Bye!\n\n-H."""
+    splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
+    output = splitter.split_text(text)
+    expected_output = [
+        "Hi.",
+        "I'm",
+        "Harrison.",
+        "How? Are?",
+        "You?",
+        "Okay then f",
+        "f f f f.",
+        "This is a",
+        "a weird",
+        "text to",
+        "write, but",
+        "gotta test",
+        "the",
+        "splitting",
+        "gggg",
+        "some how.",
+        "Bye!\n\n-H.",
+    ]
+    assert output == expected_output
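
Usage note (appended after the patch, not part of it): a minimal sketch of how the two splitters behave after this change, assuming the import path used in the tests above. The sample text and variable names are illustrative only.

    from langchain.text_splitter import (
        CharacterTextSplitter,
        RecursiveCharacterTextSplitter,
    )

    text = "one two three\nfour five six\n\nseven eight nine"

    # CharacterTextSplitter still splits on a single separator, which this
    # patch moves from the TextSplitter base class into its own __init__.
    char_splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=20, chunk_overlap=5
    )
    print(char_splitter.split_text(text))
    # The first piece ("one two three\nfour five six") stays intact (with a
    # logged warning), since there is no smaller separator to fall back to.

    # RecursiveCharacterTextSplitter walks its separator list
    # ["\n\n", "\n", " ", ""], re-splitting any piece still over chunk_size
    # with the next separator down.
    recursive_splitter = RecursiveCharacterTextSplitter(
        chunk_size=20, chunk_overlap=5
    )
    print(recursive_splitter.split_text(text))
    # -> ['one two three', 'four five six', 'seven eight nine']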
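The while-condition added to _merge_splits is the subtle part of this patch, so here is a standalone mirror of that loop (a hypothetical helper, simplified to length_function=len; the merge_splits name is mine, not an API in the library), with the short-words-first case from the new test worked through.

    from typing import List

    def merge_splits(
        splits: List[str], separator: str, chunk_size: int, chunk_overlap: int
    ) -> List[str]:
        # Mirrors TextSplitter._merge_splits as it stands after this patch.
        docs: List[str] = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = len(d)
            if total + _len >= chunk_size:
                if len(current_doc) > 0:
                    docs.append(separator.join(current_doc))
                    # Pop from the front while the retained total exceeds the
                    # overlap, or while the incoming split still would not fit.
                    while total > chunk_overlap or (
                        total + _len > chunk_size and total > 0
                    ):
                        total -= len(current_doc[0])
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len
        docs.append(separator.join(current_doc))
        return docs

    # With only the old condition (total > chunk_overlap), the leading "a"
    # would be kept as overlap and merged into an oversized "a foo" chunk.
    # The new second clause keeps popping until "foo" actually fits.
    print(merge_splits(["a", "a", "foo", "bar", "baz"], " ", 3, 1))
    # -> ['a a', 'foo', 'bar', 'baz'], matching
    # test_character_text_splitter_short_words_first.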