partners: (langchain-huggingface) Embeddings - Integrate Inference Providers and remove deprecated code (#30735)

Hi there! This is a complementary PR to #30733.
This PR introduces support for Hugging Face's serverless Inference
Providers (documentation
[here](https://huggingface.co/docs/inference-providers/index)), allowing
users to specify different providers.

This PR also removes the usage of `InferenceClient.post()` method in
`HuggingFaceEndpointEmbeddings`, in favor of the task-specific
`feature_extraction` method. `InferenceClient.post()` is deprecated and
will be removed in `huggingface_hub` v0.31.0.

## Changes made

- bumped the minimum required version of the `huggingface_hub` package
to ensure compatibility with the latest API usage.
- added a provider field to `HuggingFaceEndpointEmbeddings`, enabling
users to select the inference provider.
- replaced the deprecated `InferenceClient.post()` call in
`HuggingFaceEndpointEmbeddings` with the task-specific
`feature_extraction` method for future-proofing; `post()` will be
removed in `huggingface-hub` v0.31.0.

All changes are backward compatible.

---------

Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
célina 2025-04-09 21:05:43 +02:00 committed by GitHub
parent 98f0016fc2
commit 68361f9c2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 9 deletions

View File

@ -1,4 +1,3 @@
import json
import os
from typing import Any, List, Optional
@ -34,6 +33,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
async_client: Any = None #: :meta private:
model: Optional[str] = None
"""Model name to use."""
provider: Optional[str] = None
"""Name of the provider to use for inference with the model specified in
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
repo_id: Optional[str] = None
"""Huggingfacehub repository id, for backward compatibility."""
task: Optional[str] = "feature-extraction"
@ -74,11 +77,13 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
client = InferenceClient(
model=self.model,
token=huggingfacehub_api_token,
provider=self.provider,
)
async_client = AsyncInferenceClient(
model=self.model,
token=huggingfacehub_api_token,
provider=self.provider,
)
if self.task not in VALID_TASKS:
@ -109,10 +114,8 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = self.model_kwargs or {}
# api doc: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/embed
responses = self.client.post(
json={"inputs": texts, **_model_kwargs}, task=self.task
)
return json.loads(responses.decode())
responses = self.client.feature_extraction(text=texts, **_model_kwargs)
return responses
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
@ -126,10 +129,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
# replace newlines, which can negatively affect performance.
texts = [text.replace("\n", " ") for text in texts]
_model_kwargs = self.model_kwargs or {}
responses = await self.async_client.post(
json={"inputs": texts, **_model_kwargs}, task=self.task
responses = await self.async_client.feature_extraction(
text=texts, **_model_kwargs
)
return json.loads(responses.decode())
return responses
def embed_query(self, text: str) -> List[float]:
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.

View File

@ -11,7 +11,7 @@ dependencies = [
"tokenizers>=0.19.1",
"transformers>=4.39.0",
"sentence-transformers>=2.6.0",
"huggingface-hub>=0.23.0",
"huggingface-hub>=0.30.2",
]
name = "langchain-huggingface"
version = "0.1.2"