feat(text-splitters): add model_kwargs to SentenceTransformersTokenTextSplitter (#35113)
@@ -25,6 +25,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         chunk_overlap: int = 50,
         model_name: str = "sentence-transformers/all-mpnet-base-v2",
         tokens_per_chunk: int | None = None,
+        model_kwargs: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         """Create a new `TextSplitter`.
@@ -35,6 +36,8 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             tokens_per_chunk: The number of tokens per chunk.
 
                 If `None`, uses the maximum tokens allowed by the model.
+            model_kwargs: Additional parameters for model initialization.
+                Parameters of sentence_transformers.SentenceTransformer can be used.
 
         Raises:
             ImportError: If the `sentence_transformers` package is not installed.
@@ -50,7 +53,7 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
             raise ImportError(msg)
 
         self.model_name = model_name
-        self._model = SentenceTransformer(self.model_name)
+        self._model = SentenceTransformer(self.model_name, **(model_kwargs or {}))
         self.tokenizer = self._model.tokenizer
         self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
 
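As a quick usage sketch of the new parameter (illustrative only, not part of the diff): per the updated docstring, any keyword accepted by the `sentence_transformers.SentenceTransformer` constructor, such as `device` or `local_files_only`, can now be forwarded through `model_kwargs`.

    from langchain_text_splitters import SentenceTransformersTokenTextSplitter

    # model_kwargs is forwarded verbatim to sentence_transformers.SentenceTransformer;
    # "device" is one example of a constructor keyword it accepts.
    splitter = SentenceTransformersTokenTextSplitter(
        model_name="sentence-transformers/all-mpnet-base-v2",
        chunk_overlap=50,
        model_kwargs={"device": "cpu"},
    )

    chunks = splitter.split_text("Lorem ipsum dolor sit amet.")

The test added below exercises the same mechanism with `local_files_only=True` to force offline loading of an already-downloaded model.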
@@ -112,3 +112,22 @@ def test_sentence_transformers_multiple_tokens() -> None:
         - splitter.maximum_tokens_per_chunk
     )
     assert expected == actual
+
+
+@pytest.mark.requires("sentence_transformers")
+def test_sentence_transformers_with_additional_model_kwargs() -> None:
+    """Test passing model_kwargs to SentenceTransformer."""
+    # ensure model is downloaded (online)
+    splitter_online = SentenceTransformersTokenTextSplitter(
+        model_name="sentence-transformers/paraphrase-albert-small-v2"
+    )
+    text = "lorem ipsum"
+    splitter_online.count_tokens(text=text)
+
+    # test offline model loading using model_kwargs
+    splitter_offline = SentenceTransformersTokenTextSplitter(
+        model_name="sentence-transformers/paraphrase-albert-small-v2",
+        model_kwargs={"local_files_only": True},
+    )
+    splitter_offline.count_tokens(text=text)
+    assert splitter_offline.tokenizer is not None