From 68361f9c2d44247e67c3f60591be346f2dfb1729 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=C3=A9lina?=
Date: Wed, 9 Apr 2025 21:05:43 +0200
Subject: [PATCH] partners: (langchain-huggingface) Embeddings - Integrate
 Inference Providers and remove deprecated code (#30735)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Hi there,

This is a complementary PR to #30733.

This PR introduces support for Hugging Face's serverless Inference
Providers (documentation
[here](https://huggingface.co/docs/inference-providers/index)), allowing
users to specify which provider to run inference with.

This PR also removes the usage of the `InferenceClient.post()` method in
`HuggingFaceEndpointEmbeddings`, in favor of the task-specific
`feature_extraction` method. `InferenceClient.post()` is deprecated and
will be removed in `huggingface_hub` v0.31.0.

## Changes made

- bumped the minimum required version of the `huggingface_hub` package to
  ensure compatibility with the latest API usage.
- added a `provider` field to `HuggingFaceEndpointEmbeddings`, enabling
  users to select the inference provider.
- replaced the deprecated `InferenceClient.post()` call in
  `HuggingFaceEndpointEmbeddings` with the task-specific
  `feature_extraction` method for future-proofing; `post()` will be
  removed in `huggingface_hub` v0.31.0.

✅ All changes are backward compatible.
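For reference, a minimal usage sketch of the new `provider` field (the model
name, provider choice, and token placeholder below are illustrative, not part
of this PR):

```python
from langchain_huggingface import HuggingFaceEndpointEmbeddings

# Illustrative: any provider listed in the huggingface_hub docs linked above
# should be usable here. Leaving `provider` unset keeps the previous behavior
# (HF Inference API), which is why the change is backward compatible.
embeddings = HuggingFaceEndpointEmbeddings(
    model="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    provider="sambanova",  # optional; defaults to the HF Inference API
    huggingfacehub_api_token="hf_xxx",  # placeholder; use a real HF token
)

query_vector = embeddings.embed_query("Hello, world!")
doc_vectors = embeddings.embed_documents(["first doc", "second doc"])
```

Since `provider` is optional and `None` preserves the old default, existing
call sites do not need to change.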
---------

Co-authored-by: Lucain
Co-authored-by: ccurme
---
 .../embeddings/huggingface_endpoint.py        | 19 +++++++++++--------
 libs/partners/huggingface/pyproject.toml     |  2 +-
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py
index baef1967ee1..896e0fde3b8 100644
--- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py
+++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py
@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Any, List, Optional
 
@@ -34,6 +33,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
     async_client: Any = None  #: :meta private:
     model: Optional[str] = None
     """Model name to use."""
+    provider: Optional[str] = None
+    """Name of the provider to use for inference with the model specified in
+        ``repo_id``, e.g. "sambanova". If not specified, defaults to the HF Inference API.
+        Available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
     repo_id: Optional[str] = None
     """Huggingfacehub repository id, for backward compatibility."""
     task: Optional[str] = "feature-extraction"
@@ -74,11 +77,13 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         client = InferenceClient(
             model=self.model,
             token=huggingfacehub_api_token,
+            provider=self.provider,
         )
 
         async_client = AsyncInferenceClient(
             model=self.model,
             token=huggingfacehub_api_token,
+            provider=self.provider,
         )
 
         if self.task not in VALID_TASKS:
@@ -109,10 +114,8 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         texts = [text.replace("\n", " ") for text in texts]
         _model_kwargs = self.model_kwargs or {}
         # api doc: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/embed
-        responses = self.client.post(
-            json={"inputs": texts, **_model_kwargs}, task=self.task
-        )
-        return json.loads(responses.decode())
+        responses = self.client.feature_extraction(text=texts, **_model_kwargs)
+        return responses
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
@@ -126,10 +129,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         # replace newlines, which can negatively affect performance.
         texts = [text.replace("\n", " ") for text in texts]
         _model_kwargs = self.model_kwargs or {}
-        responses = await self.async_client.post(
-            json={"inputs": texts, **_model_kwargs}, task=self.task
+        responses = await self.async_client.feature_extraction(
+            text=texts, **_model_kwargs
         )
-        return json.loads(responses.decode())
+        return responses
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to HuggingFaceHub's embedding endpoint for embedding query text.
diff --git a/libs/partners/huggingface/pyproject.toml b/libs/partners/huggingface/pyproject.toml
index 852ea0a9797..92d267ea1ec 100644
--- a/libs/partners/huggingface/pyproject.toml
+++ b/libs/partners/huggingface/pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
     "tokenizers>=0.19.1",
     "transformers>=4.39.0",
     "sentence-transformers>=2.6.0",
-    "huggingface-hub>=0.23.0",
+    "huggingface-hub>=0.30.2",
 ]
 name = "langchain-huggingface"
 version = "0.1.2"