"""Sentence transformers text splitter."""

from __future__ import annotations

from typing import Any, cast

from langchain_text_splitters.base import TextSplitter, Tokenizer, split_text_on_tokens

try:
    # Type ignores needed as long as sentence-transformers doesn't support Python 3.14.
    from sentence_transformers import (  # type: ignore[import-not-found, unused-ignore]
        SentenceTransformer,
    )

    _HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    _HAS_SENTENCE_TRANSFORMERS = False


class SentenceTransformersTokenTextSplitter(TextSplitter):
    """Split text into token-based chunks using a sentence-transformers tokenizer.

    Chunk size is measured in tokens of the chosen model's tokenizer and is
    capped at the model's `max_seq_length`.
    """

    def __init__(
        self,
        chunk_overlap: int = 50,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        tokens_per_chunk: int | None = None,
        model_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create a new `TextSplitter`.

        Args:
            chunk_overlap: The number of tokens to overlap between chunks.
            model_name: The name of the sentence transformer model to use.
            tokens_per_chunk: The number of tokens per chunk.

                If `None`, uses the maximum tokens allowed by the model.
            model_kwargs: Additional parameters for model initialization.
                Parameters of sentence_transformers.SentenceTransformer can be used.
            **kwargs: Additional keyword arguments forwarded to the
                `TextSplitter` base class constructor.

        Raises:
            ImportError: If the `sentence_transformers` package is not installed.
            ValueError: If `tokens_per_chunk` is not positive or exceeds the
                model's maximum sequence length.
        """
        super().__init__(**kwargs, chunk_overlap=chunk_overlap)

        if not _HAS_SENTENCE_TRANSFORMERS:
            msg = (
                "Could not import sentence_transformers python package. "
                "This is needed in order to use SentenceTransformersTokenTextSplitter. "
                "Please install it with `pip install sentence-transformers`."
            )
            raise ImportError(msg)

        self.model_name = model_name
        self._model = SentenceTransformer(self.model_name, **(model_kwargs or {}))
        self.tokenizer = self._model.tokenizer
        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)

    def _initialize_chunk_configuration(self, *, tokens_per_chunk: int | None) -> None:
        """Resolve and validate the chunk size, in tokens.

        Sets `self.maximum_tokens_per_chunk` from the model's `max_seq_length`
        and `self.tokens_per_chunk` from the argument (or the maximum when the
        argument is `None`).

        Args:
            tokens_per_chunk: Requested tokens per chunk, or `None` to use the
                model's maximum sequence length.

        Raises:
            ValueError: If the resolved chunk size is not positive, or if it is
                larger than the model's maximum sequence length.
        """
        self.maximum_tokens_per_chunk = self._model.max_seq_length

        if tokens_per_chunk is None:
            self.tokens_per_chunk = self.maximum_tokens_per_chunk
        else:
            self.tokens_per_chunk = tokens_per_chunk

        # A chunk must hold at least one token; fail fast at construction
        # instead of surfacing a confusing failure later at split time.
        if self.tokens_per_chunk <= 0:
            msg = (
                f"Argument tokens_per_chunk={self.tokens_per_chunk}"
                " must be a positive integer."
            )
            raise ValueError(msg)

        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
            msg = (
                f"The token limit of the models '{self.model_name}'"
                f" is: {self.maximum_tokens_per_chunk}."
                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                f" > maximum token limit."
            )
            raise ValueError(msg)

    def split_text(self, text: str) -> list[str]:
        """Splits the input text into smaller components by splitting text on tokens.

        This method encodes the input text using a private `_encode` method, then
        strips the start and stop token IDs from the encoded result. It returns the
        processed segments as a list of strings.

        Args:
            text: The input text to be split.

        Returns:
            A list of string components derived from the input text after
                encoding and processing.
        """

        def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
            # Drops the first and last IDs. This assumes the tokenizer wraps
            # every sequence in exactly one start and one stop special token.
            # NOTE(review): confirm this for tokenizers of non-default models.
            return self._encode(text)[1:-1]

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self.tokens_per_chunk,
            decode=self.tokenizer.decode,
            encode=encode_strip_start_and_stop_token_ids,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)

    def count_tokens(self, *, text: str) -> int:
        """Counts the number of tokens in the given text.

        This method encodes the input text using a private `_encode` method and
        calculates the total number of tokens in the encoded result. Unlike
        `split_text`, the start/stop token IDs are not stripped, so they are
        included in the count.

        Args:
            text: The input text for which the token count is calculated.

        Returns:
            The number of tokens in the encoded text.
        """
        return len(self._encode(text))

    # Very large max_length passed together with truncation disabled. This
    # presumably keeps the tokenizer from warning about inputs longer than the
    # model limit while still encoding the full text.
    _max_length_equal_32_bit_integer: int = 2**32

    def _encode(self, text: str) -> list[int]:
        """Encode `text` into token IDs without truncation.

        Args:
            text: The text to encode.

        Returns:
            The token IDs produced by the model tokenizer, including whatever
            special tokens it adds by default.
        """
        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
            text,
            max_length=self._max_length_equal_32_bit_integer,
            truncation="do_not_truncate",
        )
        return cast("list[int]", token_ids_with_start_and_end_token_ids)