Apply patch [skip ci]

open-swe[bot] 2025-07-31 00:10:16 +00:00
parent 32e5040a42
commit 2f8756d5bd


@@ -19,24 +19,32 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         super().__init__(**kwargs, chunk_overlap=chunk_overlap)

         try:
-            from sentence_transformers import SentenceTransformer
+            from transformers import AutoConfig, AutoTokenizer
         except ImportError:
             msg = (
-                "Could not import sentence_transformers python package. "
+                "Could not import transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
-                "Please install it with `pip install sentence-transformers`."
+                "Please install it with `pip install transformers`."
             )
             raise ImportError(msg)

         self.model_name = model_name
-        self._model = SentenceTransformer(self.model_name)
-        self.tokenizer = self._model.tokenizer
+        # Load tokenizer and config from transformers
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self._config = AutoConfig.from_pretrained(self.model_name)
         self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)

     def _initialize_chunk_configuration(
         self, *, tokens_per_chunk: Optional[int]
     ) -> None:
-        self.maximum_tokens_per_chunk = self._model.max_seq_length
+        # Get max_seq_length from config, fallback to max_position_embeddings
+        if hasattr(self._config, "max_seq_length"):
+            self.maximum_tokens_per_chunk = self._config.max_seq_length
+        elif hasattr(self._config, "max_position_embeddings"):
+            self.maximum_tokens_per_chunk = self._config.max_position_embeddings
+        else:
+            # Default fallback for models without explicit max length
+            self.maximum_tokens_per_chunk = 512

         if tokens_per_chunk is None:
             self.tokens_per_chunk = self.maximum_tokens_per_chunk
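
For readers who want to exercise the new fallback chain on its own, here is a minimal standalone sketch of the same logic; the MiniLM checkpoint is only an illustrative choice, not something the patch pins, and any Hugging Face model name works the same way.

# Standalone sketch of the patched max-length resolution.
from transformers import AutoConfig, AutoTokenizer

model_name = "sentence-transformers/all-MiniLM-L6-v2"  # example checkpoint only
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

if hasattr(config, "max_seq_length"):
    maximum_tokens_per_chunk = config.max_seq_length
elif hasattr(config, "max_position_embeddings"):
    # Most transformers configs only expose max_position_embeddings
    # (512 for the BERT backbone behind all-MiniLM-L6-v2).
    maximum_tokens_per_chunk = config.max_position_embeddings
else:
    # Conservative default when the config reports no explicit limit.
    maximum_tokens_per_chunk = 512

print(maximum_tokens_per_chunk)

Note that this can differ from what SentenceTransformer.max_seq_length used to report: sentence-transformers checkpoints often truncate below the backbone's positional limit (256 for all-MiniLM-L6-v2, 384 for the default all-mpnet-base-v2), so chunk sizes derived from max_position_embeddings may come out larger than before.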
@@ -102,3 +110,4 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             truncation="do_not_truncate",
         )
         return cast("list[int]", token_ids_with_start_and_end_token_ids)
+
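
And a quick end-to-end usage sketch against the patched class, assuming langchain_text_splitters and transformers are installed; the model name, chunk sizes, and sample text below are arbitrary.

# Exercise the patched splitter end to end; with this change __init__ only
# needs transformers (AutoTokenizer/AutoConfig), not sentence-transformers.
from langchain_text_splitters import SentenceTransformersTokenTextSplitter

splitter = SentenceTransformersTokenTextSplitter(
    model_name="sentence-transformers/all-mpnet-base-v2",  # class default
    tokens_per_chunk=64,
    chunk_overlap=10,
)

text = "LangChain token splitters cut long documents into overlapping chunks. " * 50
chunks = splitter.split_text(text)
print(len(chunks), splitter.maximum_tokens_per_chunk)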