Apply patch [skip ci]

This commit is contained in:
open-swe[bot] 2025-07-31 00:10:16 +00:00
parent 32e5040a42
commit 2f8756d5bd

View File

@@ -19,24 +19,32 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
super().__init__(**kwargs, chunk_overlap=chunk_overlap) super().__init__(**kwargs, chunk_overlap=chunk_overlap)
try: try:
from sentence_transformers import SentenceTransformer from transformers import AutoConfig, AutoTokenizer
except ImportError: except ImportError:
msg = ( msg = (
"Could not import sentence_transformers python package. " "Could not import transformers python package. "
"This is needed in order to for SentenceTransformersTokenTextSplitter. " "This is needed in order to for SentenceTransformersTokenTextSplitter. "
"Please install it with `pip install sentence-transformers`." "Please install it with `pip install transformers`."
) )
raise ImportError(msg) raise ImportError(msg)
self.model_name = model_name self.model_name = model_name
self._model = SentenceTransformer(self.model_name) # Load tokenizer and config from transformers
self.tokenizer = self._model.tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self._config = AutoConfig.from_pretrained(self.model_name)
self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
def _initialize_chunk_configuration( def _initialize_chunk_configuration(
self, *, tokens_per_chunk: Optional[int] self, *, tokens_per_chunk: Optional[int]
) -> None: ) -> None:
self.maximum_tokens_per_chunk = self._model.max_seq_length # Get max_seq_length from config, fallback to max_position_embeddings
if hasattr(self._config, "max_seq_length"):
self.maximum_tokens_per_chunk = self._config.max_seq_length
elif hasattr(self._config, "max_position_embeddings"):
self.maximum_tokens_per_chunk = self._config.max_position_embeddings
else:
# Default fallback for models without explicit max length
self.maximum_tokens_per_chunk = 512
if tokens_per_chunk is None: if tokens_per_chunk is None:
self.tokens_per_chunk = self.maximum_tokens_per_chunk self.tokens_per_chunk = self.maximum_tokens_per_chunk
@@ -102,3 +110,4 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
truncation="do_not_truncate", truncation="do_not_truncate",
) )
return cast("list[int]", token_ids_with_start_and_end_token_ids) return cast("list[int]", token_ids_with_start_and_end_token_ids)