mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 22:42:05 +00:00
huggingface[patch]: ruff fixes and rules (#31912)
* bump ruff deps * add more thorough ruff rules * fix said rules
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from ..utils.import_utils import (
|
||||
from langchain_huggingface.utils.import_utils import (
|
||||
IMPORT_ERROR,
|
||||
is_ipex_available,
|
||||
is_optimum_intel_available,
|
||||
@@ -33,12 +35,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
model_kwargs=model_kwargs,
|
||||
encode_kwargs=encode_kwargs
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
|
||||
@@ -46,12 +49,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
|
||||
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the documents of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the query of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
multi_process: bool = False
|
||||
@@ -65,24 +68,25 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
import sentence_transformers # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
) from exc
|
||||
)
|
||||
raise ImportError(msg) from exc
|
||||
|
||||
if self.model_kwargs.get("backend", "torch") == "ipex":
|
||||
if not is_optimum_intel_available() or not is_ipex_available():
|
||||
raise ImportError(
|
||||
f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
|
||||
)
|
||||
msg = f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
|
||||
raise ImportError(msg)
|
||||
|
||||
if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
|
||||
raise ImportError(
|
||||
msg = (
|
||||
f"Backend: ipex requires optimum-intel>="
|
||||
f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
|
||||
"`pip install --upgrade --upgrade-strategy eager "
|
||||
"`optimum[ipex]`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
from optimum.intel import IPEXSentenceTransformer # type: ignore[import]
|
||||
|
||||
@@ -104,21 +108,21 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
def _embed(
|
||||
self, texts: list[str], encode_kwargs: dict[str, Any]
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Embed a text using the HuggingFace transformer model.
|
||||
"""Embed a text using the HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
encode_kwargs: Keyword arguments to pass when calling the
|
||||
`encode` method for the documents of the SentenceTransformer
|
||||
encode method.
|
||||
encode method.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
import sentence_transformers # type: ignore[import]
|
||||
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
texts = [x.replace("\n", " ") for x in texts]
|
||||
if self.multi_process:
|
||||
pool = self._client.start_multi_process_pool()
|
||||
embeddings = self._client.encode_multi_process(texts, pool)
|
||||
@@ -127,16 +131,17 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self._client.encode(
|
||||
texts,
|
||||
show_progress_bar=self.show_progress,
|
||||
**encode_kwargs, # type: ignore
|
||||
**encode_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(embeddings, list):
|
||||
raise TypeError(
|
||||
msg = (
|
||||
"Expected embeddings to be a Tensor or a numpy array, "
|
||||
"got a list instead."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
return embeddings.tolist()
|
||||
return embeddings.tolist() # type: ignore[return-type]
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
@@ -146,6 +151,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
return self._embed(texts, self.encode_kwargs)
|
||||
|
||||
@@ -157,6 +163,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
|
||||
"""
|
||||
embed_kwargs = (
|
||||
self.query_encode_kwargs
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -27,6 +29,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
task="feature-extraction",
|
||||
huggingfacehub_api_token="my-api-key",
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
@@ -35,7 +38,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
"""Model name to use."""
|
||||
provider: Optional[str] = None
|
||||
"""Name of the provider to use for inference with the model specified in
|
||||
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
|
||||
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
|
||||
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
|
||||
repo_id: Optional[str] = None
|
||||
"""Huggingfacehub repository id, for backward compatibility."""
|
||||
@@ -87,18 +90,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
|
||||
if self.task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Got invalid task {self.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self.client = client
|
||||
self.async_client = async_client
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import huggingface_hub python package. "
|
||||
"Please install it with `pip install huggingface_hub`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
@@ -109,6 +114,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
@@ -125,6 +131,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
@@ -142,9 +149,9 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
|
||||
"""
|
||||
response = self.embed_documents([text])[0]
|
||||
return response
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
@@ -154,6 +161,6 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
|
||||
"""
|
||||
response = (await self.aembed_documents([text]))[0]
|
||||
return response
|
||||
return (await self.aembed_documents([text]))[0]
|
||||
|
Reference in New Issue
Block a user