From 4a8f5cdf4b3c2bc6b118fc6b01034c739193e251 Mon Sep 17 00:00:00 2001
From: kahkeng
Date: Thu, 2 Feb 2023 19:55:13 -0800
Subject: [PATCH] Add alternative token-based text splitter (#816)

This splitter does not use a separator; it naively chunks the input text at
boundaries in token space. This is helpful when we have strict token length
limits, need chunks that strictly respect the specified chunk size, and cannot
rely on aggressive separators like spaces to guarantee the absence of long
strings. CharacterTextSplitter lets such strings through without splitting
them, which could cause overflow errors downstream. Splitting at arbitrary
token boundaries is not ideal, but this is hopefully mitigated by a decent
overlap quantity. This approach also produces chunks with exactly the desired
number of tokens, instead of sometimes overcounting when shorter strings are
concatenated. Potentially also helps with #528.
---
 langchain/text_splitter.py                    | 32 +++++++++++++++++++
 tests/integration_tests/test_text_splitter.py | 20 +++++++++++-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 8a4aae2e916..0246e269e30 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -146,6 +146,38 @@ class CharacterTextSplitter(TextSplitter):
         return self._merge_splits(splits, self._separator)
 
 
+class TokenTextSplitter(TextSplitter):
+    """Implementation of splitting text that looks at tokens."""
+
+    def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to use TokenTextSplitter. "
+                "Please install it with `pip install tiktoken`."
+            )
+        # create an encoder instance for the requested encoding (GPT-2 BPE by default)
+        self._tokenizer = tiktoken.get_encoding(encoding_name)
+
+    def split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        splits = []
+        input_ids = self._tokenizer.encode(text)
+        start_idx = 0
+        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+        while start_idx < len(input_ids):
+            splits.append(self._tokenizer.decode(chunk_ids))
+            start_idx += self._chunk_size - self._chunk_overlap
+            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
+            chunk_ids = input_ids[start_idx:cur_idx]
+        return splits
+
+
 class RecursiveCharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters.
diff --git a/tests/integration_tests/test_text_splitter.py b/tests/integration_tests/test_text_splitter.py
index 902705c7261..367899aa9ef 100644
--- a/tests/integration_tests/test_text_splitter.py
+++ b/tests/integration_tests/test_text_splitter.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
 
 
 def test_huggingface_type_check() -> None:
@@ -21,3 +21,21 @@ def test_huggingface_tokenizer() -> None:
     )
     output = text_splitter.split_text("foo bar")
     assert output == ["foo", "bar"]
+
+
+class TestTokenTextSplitter:
+    """Test token text splitter."""
+
+    def test_basic(self) -> None:
+        """Test no overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+        assert output == expected_output
+
+    def test_overlap(self) -> None:
+        """Test with overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+        assert output == expected_output
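
For reference, a minimal usage sketch of the new splitter (not part of the
patch; the chunk_size/chunk_overlap values and sample text are illustrative,
and `tiktoken` must be installed):

    from langchain.text_splitter import TokenTextSplitter

    # Illustrative sizes; defaults come from the TextSplitter base class.
    splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=256, chunk_overlap=32)

    long_text = "LangChain is a framework for building LLM applications. " * 200
    chunks = splitter.split_text(long_text)
    # Each chunk decodes from at most 256 tokens, and consecutive chunks
    # share 32 tokens of overlap at their boundaries.
    print(len(chunks), max(len(chunk) for chunk in chunks))

Because boundaries are chosen purely in token space, every chunk except
possibly the last decodes from exactly chunk_size tokens, which is what makes
this splitter safe against downstream context-length overflows.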