diff --git a/libs/text-splitters/langchain_text_splitters/sentence_transformers.py b/libs/text-splitters/langchain_text_splitters/sentence_transformers.py index 03bde0a60c5..a80866c2ac1 100644 --- a/libs/text-splitters/langchain_text_splitters/sentence_transformers.py +++ b/libs/text-splitters/langchain_text_splitters/sentence_transformers.py @@ -25,6 +25,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): chunk_overlap: int = 50, model_name: str = "sentence-transformers/all-mpnet-base-v2", tokens_per_chunk: int | None = None, + model_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Create a new `TextSplitter`. @@ -35,6 +36,8 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): tokens_per_chunk: The number of tokens per chunk. If `None`, uses the maximum tokens allowed by the model. + model_kwargs: Additional keyword arguments forwarded to model initialization. + Any parameter accepted by `sentence_transformers.SentenceTransformer` can be used. Raises: ImportError: If the `sentence_transformers` package is not installed. 
@@ -50,7 +53,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): raise ImportError(msg) self.model_name = model_name - self._model = SentenceTransformer(self.model_name) + self._model = SentenceTransformer(self.model_name, **(model_kwargs or {})) self.tokenizer = self._model.tokenizer self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) diff --git a/libs/text-splitters/tests/integration_tests/test_text_splitter.py b/libs/text-splitters/tests/integration_tests/test_text_splitter.py index 52941064e34..c5de052f4b7 100644 --- a/libs/text-splitters/tests/integration_tests/test_text_splitter.py +++ b/libs/text-splitters/tests/integration_tests/test_text_splitter.py @@ -112,3 +112,22 @@ def test_sentence_transformers_multiple_tokens() -> None: - splitter.maximum_tokens_per_chunk ) assert expected == actual + + +@pytest.mark.requires("sentence_transformers") +def test_sentence_transformers_with_additional_model_kwargs() -> None: + """Test passing model_kwargs to SentenceTransformer.""" + # ensure model is downloaded (online) + splitter_online = SentenceTransformersTokenTextSplitter( + model_name="sentence-transformers/paraphrase-albert-small-v2" + ) + text = "lorem ipsum" + splitter_online.count_tokens(text=text) + + # test offline model loading using model_kwargs + splitter_offline = SentenceTransformersTokenTextSplitter( + model_name="sentence-transformers/paraphrase-albert-small-v2", + model_kwargs={"local_files_only": True}, + ) + splitter_offline.count_tokens(text=text) + assert splitter_offline.tokenizer is not None