refactor: remove sentence-transformers dependency

Mason Daugherty 2025-07-30 20:45:04 -04:00
parent 32e5040a42
commit 0da51d0c12
13 changed files with 497 additions and 9 deletions


@@ -4,8 +4,12 @@ from langchain_huggingface.embeddings.huggingface import (
from langchain_huggingface.embeddings.huggingface_endpoint import (
HuggingFaceEndpointEmbeddings,
)
from langchain_huggingface.embeddings.transformers_embeddings import (
TransformersEmbeddings,
)
__all__ = [
"HuggingFaceEmbeddings",
"HuggingFaceEndpointEmbeddings",
"TransformersEmbeddings",
]
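For reference, a minimal sketch of the updated public surface (illustrative; assumes a build of langchain-huggingface that includes this change):

    # Both the legacy class and its lightweight replacement are exported.
    from langchain_huggingface import (
        HuggingFaceEmbeddings,  # deprecated, see warning below
        HuggingFaceEndpointEmbeddings,
        TransformersEmbeddings,  # new, uses transformers directly
    )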


@@ -1,5 +1,6 @@
from __future__ import annotations
import warnings
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
@@ -20,6 +21,11 @@ _MIN_OPTIMUM_VERSION = "1.22"
class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""HuggingFace sentence_transformers embedding models.
.. deprecated:: 0.3.1
HuggingFaceEmbeddings depends on sentence-transformers, which requires
heavy dependencies (torch, pillow). Use TransformersEmbeddings instead,
which provides the same functionality with lighter dependencies.
To use, you should have the ``sentence_transformers`` python package installed.
Example:
@@ -64,6 +70,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
warnings.warn(
"HuggingFaceEmbeddings depends on sentence-transformers which requires "
"heavy dependencies (torch, pillow). Use TransformersEmbeddings instead "
"which provides the same functionality with lighter dependencies.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(**kwargs)
try:
import sentence_transformers # type: ignore[import]
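A short sketch of how downstream code might surface the new warning in tests (illustrative; assumes sentence-transformers is installed so initialization succeeds):

    import warnings

    from langchain_huggingface import HuggingFaceEmbeddings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        HuggingFaceEmbeddings()  # default model; emits the DeprecationWarning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)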


@@ -0,0 +1,171 @@
from __future__ import annotations
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
class TransformersEmbeddings(BaseModel, Embeddings):
"""HuggingFace transformers embedding models.
This replaces HuggingFaceEmbeddings by using transformers directly
instead of sentence-transformers, avoiding the pillow dependency.
To use, you should have the ``transformers`` and ``torch`` python packages
installed.
Example:
.. code-block:: python
from langchain_huggingface import TransformersEmbeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
embeddings = TransformersEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
"""
model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
"""Model name to use."""
cache_dir: Optional[str] = None
"""Path to store models.
Can also be set via the ``TRANSFORMERS_CACHE`` environment variable.
"""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the transformers model."""
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the encode method."""
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when encoding queries."""
normalize_embeddings: bool = True
"""Whether to normalize embeddings to unit length."""
show_progress: bool = False
"""Whether to show a progress bar."""
def __init__(self, **kwargs: Any):
"""Initialize the transformers embedding model."""
super().__init__(**kwargs)
try:
import torch
from transformers import AutoModel, AutoTokenizer
except ImportError as exc:
msg = (
"Could not import transformers or torch python packages. "
"Please install them with `pip install transformers torch`."
)
raise ImportError(msg) from exc
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name, cache_dir=self.cache_dir, **self.model_kwargs
)
self._model = AutoModel.from_pretrained(
self.model_name, cache_dir=self.cache_dir, **self.model_kwargs
)
# Set model to evaluation mode
self._model.eval()
# Import torch for tensor operations
self._torch = torch
model_config = ConfigDict(
extra="forbid",
protected_namespaces=(),
populate_by_name=True,
)
def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
"""Apply mean pooling to get sentence embeddings."""
token_embeddings = model_output[
0
] # First element contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return self._torch.sum(
token_embeddings * input_mask_expanded, 1
) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def _embed(
self, texts: list[str], encode_kwargs: dict[str, Any]
) -> list[list[float]]:
"""Embed a list of texts using the transformers model.
Args:
texts: The list of texts to embed.
encode_kwargs: Additional keyword arguments for encoding.
Returns:
List of embeddings, one for each text.
"""
# Clean texts
texts = [x.replace("\n", " ") for x in texts]
# Tokenize texts
encoded_input = self._tokenizer(
texts, padding=True, truncation=True, return_tensors="pt", max_length=512
)
# Generate embeddings
with self._torch.no_grad():
model_output = self._model(**encoded_input)
# Apply mean pooling
sentence_embeddings = self._mean_pooling(
model_output, encoded_input["attention_mask"]
)
# Normalize embeddings if requested
if self.normalize_embeddings or encode_kwargs.get(
"normalize_embeddings", False
):
sentence_embeddings = self._torch.nn.functional.normalize(
sentence_embeddings, p=2, dim=1
)
return sentence_embeddings.cpu().numpy().tolist()
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Compute doc embeddings using a transformers model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
return self._embed(texts, self.encode_kwargs)
def embed_query(self, text: str) -> list[float]:
"""Compute query embeddings using a transformers model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
embed_kwargs = (
self.query_encode_kwargs
if len(self.query_encode_kwargs) > 0
else self.encode_kwargs
)
return self._embed([text], embed_kwargs)[0]
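Taken together, _embed is tokenize, forward pass, attention-masked mean pooling, then optional L2 normalization. A standalone sketch of the same pipeline (illustrative; mirrors the logic above rather than calling the class):

    import torch
    from transformers import AutoModel, AutoTokenizer

    name = "sentence-transformers/all-mpnet-base-v2"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name).eval()

    batch = tokenizer(
        ["hello world"], padding=True, truncation=True,
        max_length=512, return_tensors="pt",
    )
    with torch.no_grad():
        out = model(**batch)

    tokens = out[0]  # (batch, seq, hidden) token embeddings
    mask = batch["attention_mask"].unsqueeze(-1).expand(tokens.size()).float()
    pooled = (tokens * mask).sum(1) / mask.sum(1).clamp(min=1e-9)  # mean pooling
    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)  # unit length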


@@ -26,6 +26,10 @@ full = [
"transformers>=4.39.0",
"sentence-transformers>=2.6.0",
]
lite = [
"transformers>=4.39.0",
"torch>=1.12.0",
]
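A quick, illustrative check that the lite path works without sentence-transformers (whether that package is absent depends on the environment, not on this change):

    import importlib.util

    # sentence-transformers is not required by the lite extra.
    print(importlib.util.find_spec("sentence_transformers"))  # may be None

    from langchain_huggingface import TransformersEmbeddings  # imports regardless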
[dependency-groups]
test = [


@@ -0,0 +1,76 @@
"""Unit tests for TransformersEmbeddings."""
from typing import Any
from unittest.mock import MagicMock, patch
from langchain_huggingface.embeddings.transformers_embeddings import (
TransformersEmbeddings,
)
class TestTransformersEmbeddings:
"""Test TransformersEmbeddings."""
@patch("transformers.AutoModel")
@patch("transformers.AutoTokenizer")
@patch("torch.no_grad")
def test_initialization_success(
self, mock_no_grad: Any, mock_tokenizer: Any, mock_model: Any
) -> None:
"""Test successful initialization with mocked dependencies."""
# Mock tokenizer
mock_tokenizer_instance = MagicMock()
mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
# Mock model
mock_model_instance = MagicMock()
mock_model.from_pretrained.return_value = mock_model_instance
embeddings = TransformersEmbeddings(
model_name="sentence-transformers/all-mpnet-base-v2",
normalize_embeddings=True,
)
assert embeddings.model_name == "sentence-transformers/all-mpnet-base-v2"
assert embeddings.normalize_embeddings is True
mock_tokenizer.from_pretrained.assert_called_once()
mock_model.from_pretrained.assert_called_once()
@patch("transformers.AutoModel")
@patch("transformers.AutoTokenizer")
@patch("torch.no_grad")
def test_configuration_properties(
self, mock_no_grad: Any, mock_tokenizer: Any, mock_model: Any
) -> None:
"""Test that configuration properties are set correctly."""
# Mock tokenizer and model
mock_tokenizer.from_pretrained.return_value = MagicMock()
mock_model.from_pretrained.return_value = MagicMock()
embeddings = TransformersEmbeddings(
model_name="test-model",
cache_dir="./test_cache",
normalize_embeddings=False,
show_progress=True,
)
assert embeddings.model_name == "test-model"
assert embeddings.cache_dir == "./test_cache"
assert embeddings.normalize_embeddings is False
assert embeddings.show_progress is True
def test_model_config(self) -> None:
"""Test that model configuration is set correctly."""
# No need to initialize the actual model, just test the class attributes
config = TransformersEmbeddings.model_config
assert config["extra"] == "forbid"
assert config["populate_by_name"] is True
def test_default_values(self) -> None:
"""Test default field values without initializing."""
from langchain_huggingface.embeddings.transformers_embeddings import (
DEFAULT_MODEL_NAME,
)
# Test that default values are set correctly at class level
assert DEFAULT_MODEL_NAME == "sentence-transformers/all-mpnet-base-v2"


@@ -925,7 +925,7 @@ wheels = [
[[package]]
name = "langchain-core"
-version = "0.3.70"
+version = "0.3.72"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -996,6 +996,10 @@ full = [
{ name = "sentence-transformers" },
{ name = "transformers" },
]
lite = [
{ name = "torch" },
{ name = "transformers" },
]
[package.dev-dependencies]
codespell = [
@@ -1032,9 +1036,11 @@ requires-dist = [
{ name = "langchain-core", editable = "../../core" },
{ name = "sentence-transformers", marker = "extra == 'full'", specifier = ">=2.6.0" },
{ name = "tokenizers", specifier = ">=0.19.1" },
{ name = "torch", marker = "extra == 'lite'", specifier = ">=1.12.0" },
{ name = "transformers", marker = "extra == 'full'", specifier = ">=4.39.0" },
{ name = "transformers", marker = "extra == 'lite'", specifier = ">=4.39.0" },
]
-provides-extras = ["full"]
+provides-extras = ["full", "lite"]
[package.metadata.requires-dev]
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]


@@ -52,6 +52,7 @@ from langchain_text_splitters.sentence_transformers import (
SentenceTransformersTokenTextSplitter,
)
from langchain_text_splitters.spacy import SpacyTextSplitter
from langchain_text_splitters.transformers_token import TransformersTokenTextSplitter
__all__ = [
"CharacterTextSplitter",
@@ -77,5 +78,6 @@ __all__ = [
"TextSplitter",
"TokenTextSplitter",
"Tokenizer",
"TransformersTokenTextSplitter",
"split_text_on_tokens",
]


@@ -1,12 +1,19 @@
from __future__ import annotations
import warnings
from typing import Any, Optional, cast
from langchain_text_splitters.base import TextSplitter, Tokenizer, split_text_on_tokens
class SentenceTransformersTokenTextSplitter(TextSplitter):
"""Splitting text to tokens using sentence model tokenizer."""
"""Splitting text to tokens using sentence model tokenizer.
.. deprecated:: 0.3.9
SentenceTransformersTokenTextSplitter is deprecated due to heavy dependencies
(torch, pillow). Use TransformersTokenTextSplitter instead, which provides
the same functionality with lighter dependencies.
"""
def __init__(
self,
@@ -16,6 +23,13 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
warnings.warn(
"SentenceTransformersTokenTextSplitter is deprecated due to heavy "
"dependencies (torch, pillow). Use TransformersTokenTextSplitter "
"instead which provides the same functionality with lighter dependencies.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(**kwargs, chunk_overlap=chunk_overlap)
try:


@@ -0,0 +1,132 @@
from __future__ import annotations
from typing import Any, Optional, cast
from langchain_text_splitters.base import TextSplitter, Tokenizer, split_text_on_tokens
class TransformersTokenTextSplitter(TextSplitter):
"""Splitting text to tokens using transformers tokenizer.
This replaces SentenceTransformersTokenTextSplitter by using the transformers
library directly instead of sentence-transformers, avoiding the heavy
dependencies of torch and pillow.
"""
def __init__(
self,
chunk_overlap: int = 50,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
tokens_per_chunk: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter using transformers tokenizer.
Args:
chunk_overlap: Number of tokens to overlap between chunks.
model_name: The model name to use for tokenization.
tokens_per_chunk: Maximum number of tokens per chunk.
**kwargs: Additional arguments passed to TextSplitter.
"""
super().__init__(**kwargs, chunk_overlap=chunk_overlap)
try:
from transformers import AutoTokenizer # type: ignore[attr-defined]
except ImportError as exc:
msg = (
"Could not import transformers python package. "
"This is needed in order for TransformersTokenTextSplitter to work. "
"Please install it with `pip install transformers`."
)
raise ImportError(msg) from exc
self.model_name = model_name
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name, clean_up_tokenization_spaces=False
)
# Set a reasonable default if no model_max_length is found
default_max_length = 512
self.maximum_tokens_per_chunk = getattr(
self.tokenizer, "model_max_length", default_max_length
)
# Handle cases where model_max_length is very large or None
if (
self.maximum_tokens_per_chunk is None
or self.maximum_tokens_per_chunk > 100000
):
self.maximum_tokens_per_chunk = default_max_length
self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
def _initialize_chunk_configuration(
self, *, tokens_per_chunk: Optional[int]
) -> None:
"""Initialize the chunk size configuration."""
if tokens_per_chunk is None:
self.tokens_per_chunk = self.maximum_tokens_per_chunk
else:
self.tokens_per_chunk = tokens_per_chunk
if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
msg = (
f"The token limit of the models '{self.model_name}'"
f" is: {self.maximum_tokens_per_chunk}."
f" Argument tokens_per_chunk={self.tokens_per_chunk}"
f" > maximum token limit."
)
raise ValueError(msg)
def split_text(self, text: str) -> list[str]:
"""Splits the input text into smaller components by splitting text on tokens.
This method encodes the input text using the transformers tokenizer, then
strips the start and stop token IDs from the encoded result. It returns the
processed segments as a list of strings.
Args:
text: The input text to be split.
Returns:
A list of string components derived from the input text after
encoding and processing.
"""
def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
return self._encode(text)[1:-1]
tokenizer = Tokenizer(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self.tokens_per_chunk,
decode=self.tokenizer.decode,
encode=encode_strip_start_and_stop_token_ids,
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)
def count_tokens(self, *, text: str) -> int:
"""Counts the number of tokens in the given text.
This method encodes the input text using the transformers tokenizer and
calculates the total number of tokens in the encoded result.
Args:
text: The input text for which the token count is calculated.
Returns:
The number of tokens in the encoded text.
"""
return len(self._encode(text))
_max_length_equal_32_bit_integer: int = 2**32
def _encode(self, text: str) -> list[int]:
"""Encode text using the transformers tokenizer."""
token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
text,
max_length=self._max_length_equal_32_bit_integer,
truncation=False,
add_special_tokens=True,
)
return cast("list[int]", token_ids_with_start_and_end_token_ids)
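A usage sketch for the new splitter (illustrative; the tokenizer is downloaded on first use):

    from langchain_text_splitters import TransformersTokenTextSplitter

    splitter = TransformersTokenTextSplitter(
        model_name="sentence-transformers/all-mpnet-base-v2",
        tokens_per_chunk=100,
        chunk_overlap=10,
    )
    chunks = splitter.split_text("This is a test sentence. " * 200)
    print(len(chunks), splitter.count_tokens(text=chunks[0]))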


@@ -40,8 +40,9 @@ test_integration = [
"spacy<4.0.0,>=3.8.7",
"thinc<9.0.0,>=8.3.6",
"nltk<4.0.0,>=3.9.1",
"transformers<5.0.0,>=4.51.3",
"transformers>=4.51.3,<5.0.0",
"sentence-transformers>=3.0.1",
"torch>=1.12.0",
]
[tool.uv.sources]


@@ -0,0 +1,62 @@
"""Unit tests for TransformersTokenTextSplitter."""
import pytest
from langchain_text_splitters.transformers_token import TransformersTokenTextSplitter
class TestTransformersTokenTextSplitter:
"""Test TransformersTokenTextSplitter."""
def test_initialization(self) -> None:
"""Test that the splitter can be initialized."""
try:
splitter = TransformersTokenTextSplitter(
model_name="sentence-transformers/all-mpnet-base-v2",
chunk_overlap=10,
tokens_per_chunk=100,
)
assert splitter.model_name == "sentence-transformers/all-mpnet-base-v2"
assert splitter.tokens_per_chunk == 100
assert splitter._chunk_overlap == 10
except ImportError:
pytest.skip("transformers not available")
def test_split_text(self) -> None:
"""Test basic text splitting functionality."""
try:
splitter = TransformersTokenTextSplitter(
model_name="sentence-transformers/all-mpnet-base-v2",
tokens_per_chunk=10,
)
text = "This is a test sentence. " * 20
chunks = splitter.split_text(text)
assert isinstance(chunks, list)
assert len(chunks) > 1
assert all(isinstance(chunk, str) for chunk in chunks)
except ImportError:
pytest.skip("transformers not available")
def test_count_tokens(self) -> None:
"""Test token counting functionality."""
try:
splitter = TransformersTokenTextSplitter(
model_name="sentence-transformers/all-mpnet-base-v2"
)
text = "This is a test sentence."
token_count = splitter.count_tokens(text=text)
assert isinstance(token_count, int)
assert token_count > 0
except ImportError:
pytest.skip("transformers not available")
def test_tokens_per_chunk_validation(self) -> None:
"""Test that tokens_per_chunk is validated against model limits."""
try:
with pytest.raises(ValueError, match="maximum token limit"):
TransformersTokenTextSplitter(
model_name="sentence-transformers/all-mpnet-base-v2",
tokens_per_chunk=100000, # Way too large
)
except ImportError:
pytest.skip("transformers not available")


@@ -1178,6 +1178,7 @@ test-integration = [
{ name = "sentence-transformers" },
{ name = "spacy" },
{ name = "thinc" },
{ name = "torch" },
{ name = "transformers" },
]
typing = [
@@ -1214,6 +1215,7 @@ test-integration = [
{ name = "sentence-transformers", specifier = ">=3.0.1" },
{ name = "spacy", specifier = ">=3.8.7,<4.0.0" },
{ name = "thinc", specifier = ">=8.3.6,<9.0.0" },
{ name = "torch", specifier = ">=1.12.0" },
{ name = "transformers", specifier = ">=4.51.3,<5.0.0" },
]
typing = [

uv.lock

@@ -181,7 +181,7 @@ wheels = [
[[package]]
name = "anthropic"
-version = "0.57.1"
+version = "0.60.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -192,9 +192,9 @@ dependencies = [
{ name = "sniffio" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/d7/75/6261a1a8d92aed47e27d2fcfb3a411af73b1435e6ae1186da02b760565d0/anthropic-0.57.1.tar.gz", hash = "sha256:7815dd92245a70d21f65f356f33fc80c5072eada87fb49437767ea2918b2c4b0", size = 423775, upload-time = "2025-07-03T16:57:35.932Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/4e/03/3334921dc54ed822b3dd993ae72d823a7402588521bbba3e024b3333a1fd/anthropic-0.60.0.tar.gz", hash = "sha256:a22ba187c6f4fd5afecb2fc913b960feccf72bc0d25c1b7ce0345e87caede577", size = 425983, upload-time = "2025-07-28T19:53:47.685Z" }
wheels = [
-{ url = "https://files.pythonhosted.org/packages/e5/cf/ca0ba77805aec6171629a8b665c7dc224dab374539c3d27005b5d8c100a0/anthropic-0.57.1-py3-none-any.whl", hash = "sha256:33afc1f395af207d07ff1bffc0a3d1caac53c371793792569c5d2f09283ea306", size = 292779, upload-time = "2025-07-03T16:57:34.636Z" },
+{ url = "https://files.pythonhosted.org/packages/da/bb/d84f287fb1c217b30c328af987cf8bbe3897edf0518dcc5fa39412f794ec/anthropic-0.60.0-py3-none-any.whl", hash = "sha256:65ad1f088a960217aaf82ba91ff743d6c89e9d811c6d64275b9a7c59ee9ac3c6", size = 293116, upload-time = "2025-07-28T19:53:45.944Z" },
]
[[package]]
@@ -2354,7 +2354,7 @@ typing = [
[[package]]
name = "langchain-anthropic"
-version = "0.3.17"
+version = "0.3.18"
source = { editable = "libs/partners/anthropic" }
dependencies = [
{ name = "anthropic" },
@@ -2364,7 +2364,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "anthropic", specifier = ">=0.57.0,<1" },
{ name = "anthropic", specifier = ">=0.60.0,<1" },
{ name = "langchain-core", editable = "libs/core" },
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
]
@@ -2918,6 +2918,7 @@ test-integration = [
{ name = "sentence-transformers", specifier = ">=3.0.1" },
{ name = "spacy", specifier = ">=3.8.7,<4.0.0" },
{ name = "thinc", specifier = ">=8.3.6,<9.0.0" },
{ name = "torch", specifier = ">=1.12.0" },
{ name = "transformers", specifier = ">=4.51.3,<5.0.0" },
]
typing = [