mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
Apply patch [skip ci]
This commit is contained in:
parent
32e5040a42
commit
2f8756d5bd
@ -19,24 +19,32 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
|
||||
super().__init__(**kwargs, chunk_overlap=chunk_overlap)
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
except ImportError:
|
||||
msg = (
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to for SentenceTransformersTokenTextSplitter. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
self.model_name = model_name
|
||||
self._model = SentenceTransformer(self.model_name)
|
||||
self.tokenizer = self._model.tokenizer
|
||||
# Load tokenizer and config from transformers
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self._config = AutoConfig.from_pretrained(self.model_name)
|
||||
self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
|
||||
|
||||
def _initialize_chunk_configuration(
|
||||
self, *, tokens_per_chunk: Optional[int]
|
||||
) -> None:
|
||||
self.maximum_tokens_per_chunk = self._model.max_seq_length
|
||||
# Get max_seq_length from config, fallback to max_position_embeddings
|
||||
if hasattr(self._config, "max_seq_length"):
|
||||
self.maximum_tokens_per_chunk = self._config.max_seq_length
|
||||
elif hasattr(self._config, "max_position_embeddings"):
|
||||
self.maximum_tokens_per_chunk = self._config.max_position_embeddings
|
||||
else:
|
||||
# Default fallback for models without explicit max length
|
||||
self.maximum_tokens_per_chunk = 512
|
||||
|
||||
if tokens_per_chunk is None:
|
||||
self.tokens_per_chunk = self.maximum_tokens_per_chunk
|
||||
@ -102,3 +110,4 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
|
||||
truncation="do_not_truncate",
|
||||
)
|
||||
return cast("list[int]", token_ids_with_start_and_end_token_ids)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user