feat(text-splitters): add model_kwargs to SentenceTransformersTokenTextSplitter (#35113)

Katha
2026-02-11 18:26:58 +01:00
committed by GitHub
parent 5e8a2c5309
commit 253398ebca
2 changed files with 23 additions and 1 deletion


@@ -25,6 +25,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         chunk_overlap: int = 50,
         model_name: str = "sentence-transformers/all-mpnet-base-v2",
         tokens_per_chunk: int | None = None,
+        model_kwargs: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         """Create a new `TextSplitter`.
@@ -35,6 +36,8 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             tokens_per_chunk: The number of tokens per chunk.
                 If `None`, uses the maximum tokens allowed by the model.
+            model_kwargs: Additional parameters for model initialization.
+                Parameters of `sentence_transformers.SentenceTransformer` can be used.
 
         Raises:
             ImportError: If the `sentence_transformers` package is not installed.
@@ -50,7 +53,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             raise ImportError(msg)
 
         self.model_name = model_name
-        self._model = SentenceTransformer(self.model_name)
+        self._model = SentenceTransformer(self.model_name, **(model_kwargs or {}))
        self.tokenizer = self._model.tokenizer
         self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
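
For orientation, here is a minimal usage sketch of the new parameter as it reads after this change; the `device` and `local_files_only` values (and the sample text) are illustrative assumptions, not taken from this commit:

from langchain_text_splitters import SentenceTransformersTokenTextSplitter

# Hypothetical example: forward SentenceTransformer constructor options
# (e.g. device placement, offline loading) through model_kwargs.
splitter = SentenceTransformersTokenTextSplitter(
    model_name="sentence-transformers/all-mpnet-base-v2",
    chunk_overlap=50,
    model_kwargs={"device": "cpu", "local_files_only": True},
)

chunks = splitter.split_text("Some longer document text to be chunked.")
print(len(chunks), splitter.count_tokens(text="Some longer document text."))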


@@ -112,3 +112,22 @@ def test_sentence_transformers_multiple_tokens() -> None:
         - splitter.maximum_tokens_per_chunk
     )
     assert expected == actual
+
+
+@pytest.mark.requires("sentence_transformers")
+def test_sentence_transformers_with_additional_model_kwargs() -> None:
+    """Test passing model_kwargs to SentenceTransformer."""
+    # ensure the model is downloaded (online)
+    splitter_online = SentenceTransformersTokenTextSplitter(
+        model_name="sentence-transformers/paraphrase-albert-small-v2"
+    )
+    text = "lorem ipsum"
+    splitter_online.count_tokens(text=text)
+
+    # test offline model loading using model_kwargs
+    splitter_offline = SentenceTransformersTokenTextSplitter(
+        model_name="sentence-transformers/paraphrase-albert-small-v2",
+        model_kwargs={"local_files_only": True},
+    )
+    splitter_offline.count_tokens(text=text)
+    assert splitter_offline.tokenizer is not None