mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
Add alternative token-based text splitter (#816)
This does not involve a separator, and will naively chunk input text at the appropriate boundaries in token space. This is helpful if we have strict token length limits that we need to strictly follow the specified chunk size, and we can't use aggressive separators like spaces to guarantee the absence of long strings. CharacterTextSplitter will let these strings through without splitting them, which could cause overflow errors downstream. Splitting at arbitrary token boundaries is not ideal but is hopefully mitigated by having a decent overlap quantity. Also this results in chunks which has exact number of tokens desired, instead of sometimes overcounting if we concatenate shorter strings. Potentially also helps with #528.
This commit is contained in:
parent
523ad2e6bd
commit
4a8f5cdf4b
@ -146,6 +146,38 @@ class CharacterTextSplitter(TextSplitter):
|
||||
return self._merge_splits(splits, self._separator)
|
||||
|
||||
|
||||
class TokenTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at tokens."""
|
||||
|
||||
def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to for TokenTextSplitter. "
|
||||
"Please it install it with `pip install tiktoken`."
|
||||
)
|
||||
# create a GPT-3 encoder instance
|
||||
self._tokenizer = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
splits = []
|
||||
input_ids = self._tokenizer.encode(text)
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
while start_idx < len(input_ids):
|
||||
splits.append(self._tokenizer.decode(chunk_ids))
|
||||
start_idx += self._chunk_size - self._chunk_overlap
|
||||
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
return splits
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at characters.
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
|
||||
|
||||
|
||||
def test_huggingface_type_check() -> None:
|
||||
@ -21,3 +21,21 @@ def test_huggingface_tokenizer() -> None:
|
||||
)
|
||||
output = text_splitter.split_text("foo bar")
|
||||
assert output == ["foo", "bar"]
|
||||
|
||||
|
||||
class TestTokenTextSplitter:
|
||||
"""Test token text splitter."""
|
||||
|
||||
def test_basic(self) -> None:
|
||||
"""Test no overlap."""
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
|
||||
output = splitter.split_text("abcdef" * 5) # 10 token string
|
||||
expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
|
||||
assert output == expected_output
|
||||
|
||||
def test_overlap(self) -> None:
|
||||
"""Test with overlap."""
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
|
||||
output = splitter.split_text("abcdef" * 5) # 10 token string
|
||||
expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
|
||||
assert output == expected_output
|
||||
|
Loading…
Reference in New Issue
Block a user