Apply patch [skip ci]
This commit is contained in: parent 32e5040a42, commit 2f8756d5bd
@@ -19,24 +19,32 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         super().__init__(**kwargs, chunk_overlap=chunk_overlap)
 
         try:
-            from sentence_transformers import SentenceTransformer
+            from transformers import AutoConfig, AutoTokenizer
         except ImportError:
             msg = (
-                "Could not import sentence_transformers python package. "
+                "Could not import transformers python package. "
                 "This is needed in order to for SentenceTransformersTokenTextSplitter. "
-                "Please install it with `pip install sentence-transformers`."
+                "Please install it with `pip install transformers`."
             )
             raise ImportError(msg)
 
         self.model_name = model_name
-        self._model = SentenceTransformer(self.model_name)
-        self.tokenizer = self._model.tokenizer
+        # Load tokenizer and config from transformers
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self._config = AutoConfig.from_pretrained(self.model_name)
         self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
 
     def _initialize_chunk_configuration(
         self, *, tokens_per_chunk: Optional[int]
     ) -> None:
-        self.maximum_tokens_per_chunk = self._model.max_seq_length
+        # Get max_seq_length from config, fallback to max_position_embeddings
+        if hasattr(self._config, "max_seq_length"):
+            self.maximum_tokens_per_chunk = self._config.max_seq_length
+        elif hasattr(self._config, "max_position_embeddings"):
+            self.maximum_tokens_per_chunk = self._config.max_position_embeddings
+        else:
+            # Default fallback for models without explicit max length
+            self.maximum_tokens_per_chunk = 512
 
         if tokens_per_chunk is None:
            self.tokens_per_chunk = self.maximum_tokens_per_chunk
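
The patch swaps the hard dependency on sentence_transformers for plain transformers and derives the chunk ceiling from the model config instead of SentenceTransformer.max_seq_length. A minimal sketch of that resolution order outside the class, assuming a BERT-family checkpoint (the model name below is only an illustration, not part of the patch):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Most Hugging Face configs expose max_position_embeddings rather than
# max_seq_length (which usually lives in sentence-transformers' own
# sentence_bert_config.json), so the elif branch is the common path.
if hasattr(config, "max_seq_length"):
    max_tokens = config.max_seq_length
elif hasattr(config, "max_position_embeddings"):
    max_tokens = config.max_position_embeddings  # 512 for this checkpoint
else:
    max_tokens = 512  # conservative default for models without an explicit limit

print(max_tokens)

Because HF configs rarely define max_seq_length, a checkpoint whose sentence-transformers limit is shorter than its positional ceiling may now report a larger maximum than before.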
@@ -102,3 +110,4 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             truncation="do_not_truncate",
         )
         return cast("list[int]", token_ids_with_start_and_end_token_ids)
+
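
The second hunk leaves the token-counting helper intact apart from a trailing newline. For context, a small sketch of what that encode call produces for a BERT-family tokenizer (the checkpoint name is again only an assumption for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# truncation="do_not_truncate" keeps the full sequence so the splitter can
# count every token; BERT-style tokenizers add two specials ([CLS], [SEP]),
# which is what "start and end token ids" refers to.
ids = tokenizer.encode("hello world", truncation="do_not_truncate")
print(tokenizer.convert_ids_to_tokens(ids))  # ['[CLS]', 'hello', 'world', '[SEP]']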
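For completeness, a hedged usage sketch of the splitter this patch touches. The keyword names follow the signature visible in the diff; the import path and the default model name depend on the installed langchain version, so treat both as assumptions:

from langchain_text_splitters import SentenceTransformersTokenTextSplitter  # path varies by version

splitter = SentenceTransformersTokenTextSplitter(
    model_name="sentence-transformers/all-mpnet-base-v2",  # assumed upstream default
    chunk_overlap=50,
    # tokens_per_chunk=None falls back to maximum_tokens_per_chunk
)
chunks = splitter.split_text("some long document text ...")

After this patch, only the transformers package (not sentence-transformers) needs to be installed for this class to work.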