diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 8a4aae2e916..0246e269e30 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -146,6 +146,38 @@ class CharacterTextSplitter(TextSplitter):
         return self._merge_splits(splits, self._separator)
 
 
+class TokenTextSplitter(TextSplitter):
+    """Implementation of splitting text that looks at tokens."""
+
+    def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
+        """Create a new TokenTextSplitter."""
+        super().__init__(**kwargs)
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order for TokenTextSplitter to work. "
+                "Please install it with `pip install tiktoken`."
+            )
+        # create a tokenizer instance for the requested encoding (GPT-2 BPE by default)
+        self._tokenizer = tiktoken.get_encoding(encoding_name)
+
+    def split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        splits = []
+        input_ids = self._tokenizer.encode(text)
+        start_idx = 0
+        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+        while start_idx < len(input_ids):
+            splits.append(self._tokenizer.decode(chunk_ids))
+            start_idx += self._chunk_size - self._chunk_overlap
+            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
+            chunk_ids = input_ids[start_idx:cur_idx]
+        return splits
+
+
 class RecursiveCharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters.
 
diff --git a/tests/integration_tests/test_text_splitter.py b/tests/integration_tests/test_text_splitter.py
index 902705c7261..367899aa9ef 100644
--- a/tests/integration_tests/test_text_splitter.py
+++ b/tests/integration_tests/test_text_splitter.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
 
 
 def test_huggingface_type_check() -> None:
@@ -21,3 +21,21 @@ def test_huggingface_tokenizer() -> None:
     )
     output = text_splitter.split_text("foo bar")
     assert output == ["foo", "bar"]
+
+
+class TestTokenTextSplitter:
+    """Test token text splitter."""
+
+    def test_basic(self) -> None:
+        """Test no overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+        assert output == expected_output
+
+    def test_overlap(self) -> None:
+        """Test with overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+        assert output == expected_output