mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 03:38:06 +00:00
refactor: extract token text splitter function (#5179)
# Token text splitter for sentence transformers The current TokenTextSplitter only works with OpenAi models via the `tiktoken` package. This is not clear from the name `TokenTextSplitter`. In this (first PR) a token based text splitter for sentence transformer models is added. In the future I think we should work towards injecting a tokenizer into the TokenTextSplitter to make ti more flexible. Could perhaps be reviewed by @dev2049 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
26ec845921
commit
8d9e9e013c
@ -0,0 +1,131 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "73dbcdb9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SentenceTransformersTokenTextSplitter\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to use the `SentenceTransformersTokenTextSplitter` text splitter.\n",
|
||||
"\n",
|
||||
"Language models have a token limit. You should not exceed the token limit. When you split your text into chunks it is therefore a good idea to count the number of tokens. There are many tokenizers. When you count tokens in your text you should use the same tokenizer as used in the language model. \n",
|
||||
"\n",
|
||||
"The `SentenceTransformersTokenTextSplitter` is a specialized text splitter for use with the sentence-transformer models. The default behaviour is to split the text into chunks that fit the token window of the sentence transformer model that you would like to use."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "9dd5419e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.text_splitter import SentenceTransformersTokenTextSplitter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b43e5d54",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)\n",
|
||||
"text = \"Lorem \""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "1df84cb4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"count_start_and_stop_tokens = 2\n",
|
||||
"text_token_count = splitter.count_tokens(text=text) - count_start_and_stop_tokens\n",
|
||||
"print(text_token_count)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d7ad2213",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tokens in text to split: 514\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1\n",
|
||||
"\n",
|
||||
"# `text_to_split` does not fit in a single chunk\n",
|
||||
"text_to_split = text * token_multiplier\n",
|
||||
"\n",
|
||||
"print(f\"tokens in text to split: {splitter.count_tokens(text=text_to_split)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "818aea04",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lorem\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"text_chunks = splitter.split_text(text=text_to_split)\n",
|
||||
"\n",
|
||||
"print(text_chunks[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e9ba4f23",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -5,6 +5,7 @@ import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
@ -244,6 +245,31 @@ class CharacterTextSplitter(TextSplitter):
|
||||
return self._merge_splits(splits, _separator)
|
||||
|
||||
|
||||
# should be in newer Python versions (3.10+)
|
||||
# @dataclass(frozen=True, kw_only=True, slots=True)
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
chunk_overlap: int
|
||||
tokens_per_chunk: int
|
||||
decode: Callable[[list[int]], str]
|
||||
encode: Callable[[str], List[int]]
|
||||
|
||||
|
||||
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
splits = []
|
||||
input_ids = tokenizer.encode(text)
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
while start_idx < len(input_ids):
|
||||
splits.append(tokenizer.decode(chunk_ids))
|
||||
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
return splits
|
||||
|
||||
|
||||
class TokenTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at tokens."""
|
||||
|
||||
@ -275,22 +301,84 @@ class TokenTextSplitter(TextSplitter):
|
||||
self._disallowed_special = disallowed_special
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
splits = []
|
||||
input_ids = self._tokenizer.encode(
|
||||
text,
|
||||
allowed_special=self._allowed_special,
|
||||
disallowed_special=self._disallowed_special,
|
||||
def _encode(_text: str) -> List[int]:
|
||||
return self._tokenizer.encode(
|
||||
_text,
|
||||
allowed_special=self._allowed_special,
|
||||
disallowed_special=self._disallowed_special,
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
chunk_overlap=self._chunk_overlap,
|
||||
tokens_per_chunk=self._chunk_size,
|
||||
decode=self._tokenizer.decode,
|
||||
encode=_encode,
|
||||
)
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
while start_idx < len(input_ids):
|
||||
splits.append(self._tokenizer.decode(chunk_ids))
|
||||
start_idx += self._chunk_size - self._chunk_overlap
|
||||
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
return splits
|
||||
|
||||
return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class SentenceTransformersTokenTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_overlap: int = 50,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
tokens_per_chunk: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs, chunk_overlap=chunk_overlap)
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
self.model_name = model_name
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
|
||||
|
||||
def _initialize_chunk_configuration(
|
||||
self, *, tokens_per_chunk: Optional[int]
|
||||
) -> None:
|
||||
self.maximum_tokens_per_chunk = self.tokenizer.max_len_single_sentence
|
||||
|
||||
if tokens_per_chunk is None:
|
||||
self.tokens_per_chunk = self.maximum_tokens_per_chunk
|
||||
else:
|
||||
self.tokens_per_chunk = tokens_per_chunk
|
||||
|
||||
if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
|
||||
raise ValueError(
|
||||
f"The token limit of the models '{self.model_name}'"
|
||||
f" is: {self.maximum_tokens_per_chunk}."
|
||||
f" Argument tokens_per_chunk={self.tokens_per_chunk}"
|
||||
f" > maximum token limit."
|
||||
)
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
|
||||
return self._encode(text)[1:-1]
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
chunk_overlap=self._chunk_overlap,
|
||||
tokens_per_chunk=self.tokens_per_chunk,
|
||||
decode=self.tokenizer.decode,
|
||||
encode=encode_strip_start_and_stop_token_ids,
|
||||
)
|
||||
|
||||
return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
|
||||
def count_tokens(self, *, text: str) -> int:
|
||||
return len(self._encode(text))
|
||||
|
||||
_max_length_equal_32_bit_integer = 2**32
|
||||
|
||||
def _encode(self, text: str) -> List[int]:
|
||||
token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
|
||||
text,
|
||||
max_length=self._max_length_equal_32_bit_integer,
|
||||
truncation="do_not_truncate",
|
||||
)
|
||||
return token_ids_with_start_and_end_token_ids
|
||||
|
||||
|
||||
class Language(str, Enum):
|
||||
|
@ -2,7 +2,11 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
|
||||
from langchain.text_splitter import (
|
||||
CharacterTextSplitter,
|
||||
SentenceTransformersTokenTextSplitter,
|
||||
TokenTextSplitter,
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_type_check() -> None:
|
||||
@ -44,3 +48,45 @@ def test_token_text_splitter_from_tiktoken() -> None:
|
||||
expected_tokenizer = "cl100k_base"
|
||||
actual_tokenizer = splitter._tokenizer.name
|
||||
assert expected_tokenizer == actual_tokenizer
|
||||
|
||||
|
||||
def test_sentence_transformers_count_tokens() -> None:
|
||||
splitter = SentenceTransformersTokenTextSplitter(
|
||||
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||
)
|
||||
text = "Lorem ipsum"
|
||||
|
||||
token_count = splitter.count_tokens(text=text)
|
||||
|
||||
expected_start_stop_token_count = 2
|
||||
expected_text_token_count = 2
|
||||
expected_token_count = expected_start_stop_token_count + expected_text_token_count
|
||||
|
||||
assert expected_token_count == token_count
|
||||
|
||||
|
||||
def test_sentence_transformers_split_text() -> None:
|
||||
splitter = SentenceTransformersTokenTextSplitter(
|
||||
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||
)
|
||||
text = "Lorem ipsum"
|
||||
text_chunks = splitter.split_text(text=text)
|
||||
expected_text_chunks = [text]
|
||||
assert expected_text_chunks == text_chunks
|
||||
|
||||
|
||||
def test_sentence_transformers_multiple_tokens() -> None:
|
||||
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
|
||||
text = "Lorem "
|
||||
|
||||
count_start_and_end_tokens = 2
|
||||
text_token_count = splitter.count_tokens(text=text) - count_start_and_end_tokens
|
||||
token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1
|
||||
text_chunks = splitter.split_text(text=text * token_multiplier)
|
||||
|
||||
expected_number_of_chunks = 2
|
||||
|
||||
assert expected_number_of_chunks == len(text_chunks)
|
||||
actual = splitter.count_tokens(text=text_chunks[1]) - count_start_and_end_tokens
|
||||
expected = token_multiplier * text_token_count - splitter.maximum_tokens_per_chunk
|
||||
assert expected == actual
|
||||
|
Loading…
Reference in New Issue
Block a user