mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-03 02:06:33 +00:00

Remove sentence-transformers dependency from LangChain

Co-authored-by: mdrxy <61371264+mdrxy@users.noreply.github.com>

parent 369417c3e7
commit e5dc9e9afd
@@ -216,9 +216,9 @@ def init_embeddings(
         return MistralAIEmbeddings(model=model_name, **kwargs)
     if provider == "huggingface":
-        from langchain_huggingface import HuggingFaceEmbeddings
+        from langchain_huggingface import HuggingFaceEndpointEmbeddings

-        return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
+        return HuggingFaceEndpointEmbeddings(model=model_name, **kwargs)
     if provider == "ollama":
         from langchain_ollama import OllamaEmbeddings

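With this hunk, the `huggingface` provider in `init_embeddings` resolves to `HuggingFaceEndpointEmbeddings` (hosted Inference API) instead of the local `HuggingFaceEmbeddings`, and the keyword changes from `model_name=` to `model=` to match the endpoint class. A minimal caller-side sketch, assuming a Hugging Face API token is configured in the environment and that the provider-prefixed model string works here as it does for other providers:

from langchain.embeddings import init_embeddings

# After this commit, this returns HuggingFaceEndpointEmbeddings, which calls
# the hosted Inference API rather than loading a local sentence-transformers model.
embeddings = init_embeddings("huggingface:microsoft/all-mpnet-base-v2")
vector = embeddings.embed_query("What is LangChain?")
print(len(vector))  # embedding dimensionality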
@@ -21,9 +21,6 @@ from langchain_text_splitters.markdown import (
 )
 from langchain_text_splitters.nltk import NLTKTextSplitter
 from langchain_text_splitters.python import PythonCodeTextSplitter
-from langchain_text_splitters.sentence_transformers import (
-    SentenceTransformersTokenTextSplitter,
-)
 from langchain_text_splitters.spacy import SpacyTextSplitter

 __all__ = [
@@ -41,7 +38,6 @@ __all__ = [
     "PythonCodeTextSplitter",
     "RecursiveCharacterTextSplitter",
     "RecursiveJsonSplitter",
-    "SentenceTransformersTokenTextSplitter",
     "SpacyTextSplitter",
     "TextSplitter",
     "TokenTextSplitter",
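Code that imported the splitter through the package root will now raise. A hedged compatibility sketch for code that must run on both sides of this commit (the `None` fallback is my convention, not part of the change):

try:
    from langchain_text_splitters.sentence_transformers import (
        SentenceTransformersTokenTextSplitter,
    )
except ImportError:
    # Removed by this commit (the module itself is deleted later in this diff).
    SentenceTransformersTokenTextSplitter = None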
@@ -2,7 +2,6 @@ from langchain_huggingface.chat_models import (
     ChatHuggingFace,  # type: ignore[import-not-found]
 )
 from langchain_huggingface.embeddings import (
-    HuggingFaceEmbeddings,
     HuggingFaceEndpointEmbeddings,
 )
 from langchain_huggingface.llms import (
@@ -12,7 +11,6 @@ from langchain_huggingface.llms import (

 __all__ = [
     "ChatHuggingFace",
-    "HuggingFaceEmbeddings",
     "HuggingFaceEndpoint",
     "HuggingFaceEndpointEmbeddings",
     "HuggingFacePipeline",
@@ -1,11 +1,7 @@
-from langchain_huggingface.embeddings.huggingface import (
-    HuggingFaceEmbeddings,  # type: ignore[import-not-found]
-)
 from langchain_huggingface.embeddings.huggingface_endpoint import (
     HuggingFaceEndpointEmbeddings,
 )

 __all__ = [
-    "HuggingFaceEmbeddings",
     "HuggingFaceEndpointEmbeddings",
 ]
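After the two `__init__` edits above, `HuggingFaceEndpointEmbeddings` is the only embeddings class exported from `langchain_huggingface`. A quick sketch of the resulting import surface:

from langchain_huggingface import HuggingFaceEndpointEmbeddings  # still exported

try:
    from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
    print("HuggingFaceEmbeddings is no longer part of langchain-huggingface")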
@@ -1,173 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Optional
-
-from langchain_core.embeddings import Embeddings
-from pydantic import BaseModel, ConfigDict, Field
-
-from langchain_huggingface.utils.import_utils import (
-    IMPORT_ERROR,
-    is_ipex_available,
-    is_optimum_intel_available,
-    is_optimum_intel_version,
-)
-
-DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
-
-_MIN_OPTIMUM_VERSION = "1.22"
-
-
-class HuggingFaceEmbeddings(BaseModel, Embeddings):
-    """HuggingFace sentence_transformers embedding models.
-
-    To use, you should have the ``sentence_transformers`` python package installed.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_huggingface import HuggingFaceEmbeddings
-
-            model_name = "sentence-transformers/all-mpnet-base-v2"
-            model_kwargs = {'device': 'cpu'}
-            encode_kwargs = {'normalize_embeddings': False}
-            hf = HuggingFaceEmbeddings(
-                model_name=model_name,
-                model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs
-            )
-
-    """
-
-    model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
-    """Model name to use."""
-    cache_folder: Optional[str] = None
-    """Path to store models.
-    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
-    model_kwargs: dict[str, Any] = Field(default_factory=dict)
-    """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
-    `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
-    See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
-    encode_kwargs: dict[str, Any] = Field(default_factory=dict)
-    """Keyword arguments to pass when calling the `encode` method for the documents of
-    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
-    `precision`, `normalize_embeddings`, and more.
-    See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
-    query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
-    """Keyword arguments to pass when calling the `encode` method for the query of
-    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
-    `precision`, `normalize_embeddings`, and more.
-    See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
-    multi_process: bool = False
-    """Run encode() on multiple GPUs."""
-    show_progress: bool = False
-    """Whether to show a progress bar."""
-
-    def __init__(self, **kwargs: Any):
-        """Initialize the sentence_transformer."""
-        super().__init__(**kwargs)
-        try:
-            import sentence_transformers  # type: ignore[import]
-        except ImportError as exc:
-            msg = (
-                "Could not import sentence_transformers python package. "
-                "Please install it with `pip install sentence-transformers`."
-            )
-            raise ImportError(msg) from exc
-
-        if self.model_kwargs.get("backend", "torch") == "ipex":
-            if not is_optimum_intel_available() or not is_ipex_available():
-                msg = f"Backend: ipex {IMPORT_ERROR.format('optimum[ipex]')}"
-                raise ImportError(msg)
-
-            if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
-                msg = (
-                    f"Backend: ipex requires optimum-intel>="
-                    f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
-                    "`pip install --upgrade --upgrade-strategy eager "
-                    "`optimum[ipex]`."
-                )
-                raise ImportError(msg)
-
-            from optimum.intel import IPEXSentenceTransformer  # type: ignore[import]
-
-            model_cls = IPEXSentenceTransformer
-
-        else:
-            model_cls = sentence_transformers.SentenceTransformer
-
-        self._client = model_cls(
-            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
-        )
-
-    model_config = ConfigDict(
-        extra="forbid",
-        protected_namespaces=(),
-        populate_by_name=True,
-    )
-
-    def _embed(
-        self, texts: list[str], encode_kwargs: dict[str, Any]
-    ) -> list[list[float]]:
-        """Embed a text using the HuggingFace transformer model.
-
-        Args:
-            texts: The list of texts to embed.
-            encode_kwargs: Keyword arguments to pass when calling the
-                `encode` method for the documents of the SentenceTransformer
-                encode method.
-
-        Returns:
-            List of embeddings, one for each text.
-
-        """
-        import sentence_transformers  # type: ignore[import]
-
-        texts = [x.replace("\n", " ") for x in texts]
-        if self.multi_process:
-            pool = self._client.start_multi_process_pool()
-            embeddings = self._client.encode_multi_process(texts, pool)
-            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
-        else:
-            embeddings = self._client.encode(
-                texts,
-                show_progress_bar=self.show_progress,
-                **encode_kwargs,
-            )
-
-        if isinstance(embeddings, list):
-            msg = (
-                "Expected embeddings to be a Tensor or a numpy array, "
-                "got a list instead."
-            )
-            raise TypeError(msg)
-
-        return embeddings.tolist()  # type: ignore[return-type]
-
-    def embed_documents(self, texts: list[str]) -> list[list[float]]:
-        """Compute doc embeddings using a HuggingFace transformer model.
-
-        Args:
-            texts: The list of texts to embed.
-
-        Returns:
-            List of embeddings, one for each text.
-
-        """
-        return self._embed(texts, self.encode_kwargs)
-
-    def embed_query(self, text: str) -> list[float]:
-        """Compute query embeddings using a HuggingFace transformer model.
-
-        Args:
-            text: The text to embed.
-
-        Returns:
-            Embeddings for the text.
-
-        """
-        embed_kwargs = (
-            self.query_encode_kwargs
-            if len(self.query_encode_kwargs) > 0
-            else self.encode_kwargs
-        )
-        return self._embed([text], embed_kwargs)[0]
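The deleted module above was a thin wrapper over `sentence_transformers.SentenceTransformer`. For code that depended on it, a rough standalone equivalent, assuming `sentence-transformers` is now installed by the user rather than pulled in by LangChain:

from sentence_transformers import SentenceTransformer

# Mirrors the deleted wrapper's defaults: the old DEFAULT_MODEL_NAME, CPU device.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu")

def embed_documents(texts: list[str]) -> list[list[float]]:
    # The deleted wrapper also replaced newlines before encoding.
    cleaned = [t.replace("\n", " ") for t in texts]
    return model.encode(cleaned, normalize_embeddings=False).tolist()

print(len(embed_documents(["hello world"])[0]))  # 768 for all-mpnet-base-v2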
@@ -8,7 +8,7 @@ from langchain_core.utils import from_env
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Self

-DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"
+DEFAULT_MODEL = "microsoft/all-mpnet-base-v2"
 VALID_TASKS = ("feature-extraction",)


@@ -23,7 +23,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         .. code-block:: python

             from langchain_huggingface import HuggingFaceEndpointEmbeddings
-            model = "sentence-transformers/all-mpnet-base-v2"
+            model = "microsoft/all-mpnet-base-v2"
             hf = HuggingFaceEndpointEmbeddings(
                 model=model,
                 task="feature-extraction",
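Note the default model swap from the `sentence-transformers` organization to `microsoft`. A construction sketch matching the updated docstring (assumes a valid Hugging Face token, e.g. `HUGGINGFACEHUB_API_TOKEN`, in the environment):

from langchain_huggingface import HuggingFaceEndpointEmbeddings

hf = HuggingFaceEndpointEmbeddings(
    model="microsoft/all-mpnet-base-v2",
    task="feature-extraction",
)
query_vector = hf.embed_query("What is deep learning?")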
@@ -23,8 +23,6 @@ repository = "https://github.com/langchain-ai/langchain"

 [project.optional-dependencies]
 full = [
-    "transformers>=4.39.0",
-    "sentence-transformers>=2.6.0",
 ]

 [dependency-groups]
@@ -35,8 +33,6 @@ test = [
     "pytest-socket<1.0.0,>=0.7.0",
     "scipy<2,>=1; python_version < \"3.12\"",
     "scipy<2.0.0,>=1.7.0; python_version >= \"3.12\"",
-    "transformers>=4.39.0",
-    "sentence-transformers>=2.6.0",
     "langchain-core",
     "langchain-tests",
     "langchain-community",
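With `transformers` and `sentence-transformers` dropped from the `full` extra and the test group, local-embedding support becomes a user-managed install. A small runtime probe (a hypothetical helper, not part of the commit):

import importlib.util

def sentence_transformers_installed() -> bool:
    # True only if the user ran `pip install sentence-transformers` themselves.
    return importlib.util.find_spec("sentence_transformers") is not None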
@@ -3,21 +3,10 @@
 from langchain_tests.integration_tests import EmbeddingsIntegrationTests

 from langchain_huggingface.embeddings import (
-    HuggingFaceEmbeddings,
     HuggingFaceEndpointEmbeddings,
 )


-class TestHuggingFaceEmbeddings(EmbeddingsIntegrationTests):
-    @property
-    def embeddings_class(self) -> type[HuggingFaceEmbeddings]:
-        return HuggingFaceEmbeddings
-
-    @property
-    def embedding_model_params(self) -> dict:
-        return {"model_name": "sentence-transformers/all-mpnet-base-v2"}
-
-
 class TestHuggingFaceEndpointEmbeddings(EmbeddingsIntegrationTests):
     @property
     def embeddings_class(self) -> type[HuggingFaceEndpointEmbeddings]:
@@ -25,4 +14,4 @@ class TestHuggingFaceEndpointEmbeddings(EmbeddingsIntegrationTests):

     @property
     def embedding_model_params(self) -> dict:
-        return {"model": "sentence-transformers/all-mpnet-base-v2"}
+        return {"model": "microsoft/all-mpnet-base-v2"}
@@ -48,9 +48,7 @@ from langchain_text_splitters.markdown import (
 )
 from langchain_text_splitters.nltk import NLTKTextSplitter
 from langchain_text_splitters.python import PythonCodeTextSplitter
-from langchain_text_splitters.sentence_transformers import (
-    SentenceTransformersTokenTextSplitter,
-)
 from langchain_text_splitters.spacy import SpacyTextSplitter

 __all__ = [
@@ -72,7 +70,6 @@ __all__ = [
     "PythonCodeTextSplitter",
     "RecursiveCharacterTextSplitter",
     "RecursiveJsonSplitter",
-    "SentenceTransformersTokenTextSplitter",
     "SpacyTextSplitter",
     "TextSplitter",
     "TokenTextSplitter",
@@ -1,104 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Optional, cast
-
-from langchain_text_splitters.base import TextSplitter, Tokenizer, split_text_on_tokens
-
-
-class SentenceTransformersTokenTextSplitter(TextSplitter):
-    """Splitting text to tokens using sentence model tokenizer."""
-
-    def __init__(
-        self,
-        chunk_overlap: int = 50,
-        model_name: str = "sentence-transformers/all-mpnet-base-v2",
-        tokens_per_chunk: Optional[int] = None,
-        **kwargs: Any,
-    ) -> None:
-        """Create a new TextSplitter."""
-        super().__init__(**kwargs, chunk_overlap=chunk_overlap)
-
-        try:
-            from sentence_transformers import SentenceTransformer
-        except ImportError:
-            msg = (
-                "Could not import sentence_transformers python package. "
-                "This is needed in order to for SentenceTransformersTokenTextSplitter. "
-                "Please install it with `pip install sentence-transformers`."
-            )
-            raise ImportError(msg)
-
-        self.model_name = model_name
-        self._model = SentenceTransformer(self.model_name)
-        self.tokenizer = self._model.tokenizer
-        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
-
-    def _initialize_chunk_configuration(
-        self, *, tokens_per_chunk: Optional[int]
-    ) -> None:
-        self.maximum_tokens_per_chunk = self._model.max_seq_length
-
-        if tokens_per_chunk is None:
-            self.tokens_per_chunk = self.maximum_tokens_per_chunk
-        else:
-            self.tokens_per_chunk = tokens_per_chunk
-
-        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
-            msg = (
-                f"The token limit of the models '{self.model_name}'"
-                f" is: {self.maximum_tokens_per_chunk}."
-                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
-                f" > maximum token limit."
-            )
-            raise ValueError(msg)
-
-    def split_text(self, text: str) -> list[str]:
-        """Splits the input text into smaller components by splitting text on tokens.
-
-        This method encodes the input text using a private `_encode` method, then
-        strips the start and stop token IDs from the encoded result. It returns the
-        processed segments as a list of strings.
-
-        Args:
-            text (str): The input text to be split.
-
-        Returns:
-            List[str]: A list of string components derived from the input text after
-            encoding and processing.
-        """
-
-        def encode_strip_start_and_stop_token_ids(text: str) -> list[int]:
-            return self._encode(text)[1:-1]
-
-        tokenizer = Tokenizer(
-            chunk_overlap=self._chunk_overlap,
-            tokens_per_chunk=self.tokens_per_chunk,
-            decode=self.tokenizer.decode,
-            encode=encode_strip_start_and_stop_token_ids,
-        )
-
-        return split_text_on_tokens(text=text, tokenizer=tokenizer)
-
-    def count_tokens(self, *, text: str) -> int:
-        """Counts the number of tokens in the given text.
-
-        This method encodes the input text using a private `_encode` method and
-        calculates the total number of tokens in the encoded result.
-
-        Args:
-            text (str): The input text for which the token count is calculated.
-
-        Returns:
-            int: The number of tokens in the encoded text.
-        """
-        return len(self._encode(text))
-
-    _max_length_equal_32_bit_integer: int = 2**32
-
-    def _encode(self, text: str) -> list[int]:
-        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
-            text,
-            max_length=self._max_length_equal_32_bit_integer,
-            truncation="do_not_truncate",
-        )
-        return cast("list[int]", token_ids_with_start_and_end_token_ids)
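The deleted splitter encoded text with the model's own tokenizer, stripped the start/stop token ids, and windowed the rest via `split_text_on_tokens`. A rough replacement using the Hugging Face tokenizer directly, assuming `transformers` is installed and `tokens_per_chunk` exceeds `chunk_overlap`:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")

def split_on_tokens(text: str, tokens_per_chunk: int = 384, chunk_overlap: int = 50) -> list[str]:
    # Encode without special tokens, mirroring how the deleted class stripped
    # start/stop ids before windowing.
    ids = tokenizer.encode(text, add_special_tokens=False)
    chunks: list[str] = []
    start = 0
    while start < len(ids):
        chunks.append(tokenizer.decode(ids[start : start + tokens_per_chunk]))
        start += tokens_per_chunk - chunk_overlap
    return chunks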
@@ -41,7 +41,6 @@ test_integration = [
     "thinc<9.0.0,>=8.3.6",
     "nltk<4.0.0,>=3.9.1",
     "transformers<5.0.0,>=4.51.3",
-    "sentence-transformers>=3.0.1",
 ]

 [tool.uv.sources]
@@ -8,18 +8,8 @@ from langchain_text_splitters import (
     TokenTextSplitter,
 )
 from langchain_text_splitters.character import CharacterTextSplitter
-from langchain_text_splitters.sentence_transformers import (
-    SentenceTransformersTokenTextSplitter,
-)


-@pytest.fixture
-def sentence_transformers() -> Any:
-    try:
-        import sentence_transformers
-    except ImportError:
-        pytest.skip("SentenceTransformers not installed.")
-    return sentence_transformers


 def test_huggingface_type_check() -> None:
@@ -63,58 +53,3 @@ def test_token_text_splitter_from_tiktoken() -> None:
     assert expected_tokenizer == actual_tokenizer


-def test_sentence_transformers_count_tokens(sentence_transformers: Any) -> None:
-    splitter = SentenceTransformersTokenTextSplitter(
-        model_name="sentence-transformers/paraphrase-albert-small-v2"
-    )
-    text = "Lorem ipsum"
-
-    token_count = splitter.count_tokens(text=text)
-
-    expected_start_stop_token_count = 2
-    expected_text_token_count = 5
-    expected_token_count = expected_start_stop_token_count + expected_text_token_count
-
-    assert expected_token_count == token_count
-
-
-def test_sentence_transformers_split_text(sentence_transformers: Any) -> None:
-    splitter = SentenceTransformersTokenTextSplitter(
-        model_name="sentence-transformers/paraphrase-albert-small-v2"
-    )
-    text = "lorem ipsum"
-    text_chunks = splitter.split_text(text=text)
-    expected_text_chunks = [text]
-    assert expected_text_chunks == text_chunks
-
-
-def test_sentence_transformers_multiple_tokens(sentence_transformers: Any) -> None:
-    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
-    text = "Lorem "
-
-    text_token_count_including_start_and_stop_tokens = splitter.count_tokens(text=text)
-    count_start_and_end_tokens = 2
-    token_multiplier = (
-        count_start_and_end_tokens
-        + (splitter.maximum_tokens_per_chunk - count_start_and_end_tokens)
-        // (
-            text_token_count_including_start_and_stop_tokens
-            - count_start_and_end_tokens
-        )
-        + 1
-    )
-
-    # `text_to_split` does not fit in a single chunk
-    text_to_embed = text * token_multiplier
-
-    text_chunks = splitter.split_text(text=text_to_embed)
-
-    expected_number_of_chunks = 2
-
-    assert expected_number_of_chunks == len(text_chunks)
-    actual = splitter.count_tokens(text=text_chunks[1]) - count_start_and_end_tokens
-    expected = (
-        token_multiplier * (text_token_count_including_start_and_stop_tokens - 2)
-        - splitter.maximum_tokens_per_chunk
-    )
-    assert expected == actual
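The deleted count_tokens test pinned "Lorem ipsum" at 5 text tokens plus 2 start/stop tokens for `paraphrase-albert-small-v2`. The same check restated against the raw tokenizer (assumes `transformers` is installed; the expected counts come from the deleted test, not from re-running it):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
ids = tok.encode("Lorem ipsum")  # encode() keeps the start/stop special tokens
assert len(ids) == 5 + 2  # text tokens + special tokens, per the deleted test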