Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-07 22:11:51 +00:00.
huggingface[patch]: ruff fixes and rules (#31912)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
@@ -1,9 +1,11 @@
 from __future__ import annotations
 from typing import Any, Optional
 from langchain_core.embeddings import Embeddings
 from pydantic import BaseModel, ConfigDict, Field
-from ..utils.import_utils import (
+from langchain_huggingface.utils.import_utils import (
     IMPORT_ERROR,
     is_ipex_available,
     is_optimum_intel_available,
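The only substantive change in this hunk swaps a relative import for an absolute one; the motivation is presumably one of the newly enabled tidy-imports style lint rules, though the commit does not name it. A standalone sketch of the resulting import style, not part of the diff (treating is_ipex_available as a zero-argument predicate is an assumption based on its name):

    # Illustrative only: the absolute form names the package explicitly, so the
    # module can be imported from tests or scripts without relying on its
    # position inside the package.
    from langchain_huggingface.utils.import_utils import is_ipex_available

    if is_ipex_available():
        print("IPEX backend support detected")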
@@ -33,12 +35,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
                 model_kwargs=model_kwargs,
                 encode_kwargs=encode_kwargs
             )
     """

     model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
     """Model name to use."""
     cache_folder: Optional[str] = None
     """Path to store models.
     Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
@@ -46,12 +49,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
     encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the documents of
     the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
     `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
     query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the query of
     the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
     `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
     multi_process: bool = False
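The three kwargs fields above are passed straight through to sentence-transformers. A minimal usage sketch built only from the fields and methods visible in this diff; the model name and the specific kwargs values are illustrative assumptions, and the top-level import assumes the class is re-exported at the package root, as is usual for this integration package:

    from langchain_huggingface import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",  # illustrative choice
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
        query_encode_kwargs={"batch_size": 16},
    )
    doc_vectors = embeddings.embed_documents(["LangChain integrates HuggingFace models."])
    query_vector = embeddings.embed_query("How does LangChain use HuggingFace?")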
@@ -65,24 +68,25 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         try:
             import sentence_transformers  # type: ignore[import]
         except ImportError as exc:
-            raise ImportError(
+            msg = (
                 "Could not import sentence_transformers python package. "
                 "Please install it with `pip install sentence-transformers`."
-            ) from exc
+            )
+            raise ImportError(msg) from exc

         if self.model_kwargs.get("backend", "torch") == "ipex":
             if not is_optimum_intel_available() or not is_ipex_available():
-                raise ImportError(
-                    f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
-                )
+                msg = f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
+                raise ImportError(msg)

             if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
-                raise ImportError(
+                msg = (
                     f"Backend: ipex requires optimum-intel>="
                     f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
                     "`pip install --upgrade --upgrade-strategy eager "
                     "`optimum[ipex]`."
                 )
+                raise ImportError(msg)

             from optimum.intel import IPEXSentenceTransformer  # type: ignore[import]
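This hunk applies one pattern three times: build the error message in a local variable, then pass the variable to the exception. That is the fix requested by ruff's flake8-errmsg checks and matches the "fix said rules" bullet in the commit message, though the exact rule codes are my inference rather than something stated in the diff. The same pattern shown standalone, outside the class:

    def require_package(name: str) -> None:
        """Raise a helpful ImportError when an optional dependency is missing."""
        try:
            __import__(name)
        except ImportError as exc:
            # Assemble the message first so the raise line stays short and the
            # literal string is not repeated inside the raise statement.
            msg = f"Could not import {name}. Please install it with `pip install {name}`."
            raise ImportError(msg) from exc

Calling require_package("some_missing_package") would then raise an ImportError whose message was built once, up front.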
@@ -104,21 +108,21 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     def _embed(
         self, texts: list[str], encode_kwargs: dict[str, Any]
     ) -> list[list[float]]:
-        """
-        Embed a text using the HuggingFace transformer model.
+        """Embed a text using the HuggingFace transformer model.

         Args:
             texts: The list of texts to embed.
             encode_kwargs: Keyword arguments to pass when calling the
                 `encode` method for the documents of the SentenceTransformer
                 encode method.

         Returns:
             List of embeddings, one for each text.
+
         """
         import sentence_transformers  # type: ignore[import]

-        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        texts = [x.replace("\n", " ") for x in texts]
         if self.multi_process:
             pool = self._client.start_multi_process_pool()
             embeddings = self._client.encode_multi_process(texts, pool)
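Two unrelated cleanups sit in this hunk: the docstring summary moves onto the opening quotes line, and list(map(lambda ...)) becomes a list comprehension. The comprehension is behaviour-preserving, which a quick standalone check (illustrative only, not part of the diff) makes concrete:

    texts = ["first\nline", "second"]

    # The new comprehension and the old map/lambda produce the same list.
    new_style = [x.replace("\n", " ") for x in texts]
    old_style = list(map(lambda x: x.replace("\n", " "), texts))
    assert new_style == old_style == ["first line", "second"]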
@@ -127,16 +131,17 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
             embeddings = self._client.encode(
                 texts,
                 show_progress_bar=self.show_progress,
-                **encode_kwargs,  # type: ignore
+                **encode_kwargs,
             )

         if isinstance(embeddings, list):
-            raise TypeError(
+            msg = (
                 "Expected embeddings to be a Tensor or a numpy array, "
                 "got a list instead."
             )
+            raise TypeError(msg)

-        return embeddings.tolist()
+        return embeddings.tolist()  # type: ignore[return-type]

     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
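The isinstance guard above exists because encode() may hand back different container types depending on its arguments, while _embed needs something with a .tolist() method. A numpy-only sketch of the guarded conversion, with the array standing in for encode() output (that substitution is an assumption for illustration, not something shown in the diff):

    import numpy as np

    embeddings = np.array([[0.1, 0.2], [0.3, 0.4]])  # stand-in for encode() output

    if isinstance(embeddings, list):
        msg = "Expected embeddings to be a Tensor or a numpy array, got a list instead."
        raise TypeError(msg)

    # .tolist() turns the 2-D array into list[list[float]] for the Embeddings API.
    rows: list[list[float]] = embeddings.tolist()
    print(rows)  # [[0.1, 0.2], [0.3, 0.4]]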
@@ -146,6 +151,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):

         Returns:
             List of embeddings, one for each text.
+
         """
         return self._embed(texts, self.encode_kwargs)

@@ -157,6 +163,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):

         Returns:
             Embeddings for the text.
+
         """
         embed_kwargs = (
             self.query_encode_kwargs
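The last two hunks only add a blank line before each closing docstring quote, which is the layout the pydocstyle-flavoured ruff checks ask for once a Returns section is present (the specific rule code is my assumption). A toy function, not from the diff, showing the target layout:

    def embed_query_example(text: str) -> list[float]:
        """Toy function mirroring the docstring layout after this change.

        Args:
            text: Text to embed.

        Returns:
            A placeholder embedding.

        """
        return [0.0] * 3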