diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py
index baef1967ee1..896e0fde3b8 100644
--- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py
+++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface_endpoint.py
@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Any, List, Optional
 
@@ -34,6 +33,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
     async_client: Any = None  #: :meta private:
     model: Optional[str] = None
     """Model name to use."""
+    provider: Optional[str] = None
+    """Name of the provider to use for inference with the model specified in
+    ``repo_id``, e.g. ``"sambanova"``. If not specified, defaults to the HF Inference API.
+    Available providers are listed in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
     repo_id: Optional[str] = None
     """Huggingfacehub repository id, for backward compatibility."""
     task: Optional[str] = "feature-extraction"
@@ -74,11 +77,13 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         client = InferenceClient(
             model=self.model,
             token=huggingfacehub_api_token,
+            provider=self.provider,
         )
 
         async_client = AsyncInferenceClient(
             model=self.model,
             token=huggingfacehub_api_token,
+            provider=self.provider,
         )
 
         if self.task not in VALID_TASKS:
@@ -109,10 +114,8 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         texts = [text.replace("\n", " ") for text in texts]
         _model_kwargs = self.model_kwargs or {}
         # api doc: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/embed
-        responses = self.client.post(
-            json={"inputs": texts, **_model_kwargs}, task=self.task
-        )
-        return json.loads(responses.decode())
+        responses = self.client.feature_extraction(text=texts, **_model_kwargs)
+        return responses
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
@@ -126,10 +129,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         # replace newlines, which can negatively affect performance.
         texts = [text.replace("\n", " ") for text in texts]
         _model_kwargs = self.model_kwargs or {}
-        responses = await self.async_client.post(
-            json={"inputs": texts, **_model_kwargs}, task=self.task
-        )
-        return json.loads(responses.decode())
+        responses = await self.async_client.feature_extraction(
+            text=texts, **_model_kwargs
+        )
+        return responses
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to HuggingFaceHub's embedding endpoint for embedding query text.
diff --git a/libs/partners/huggingface/pyproject.toml b/libs/partners/huggingface/pyproject.toml
index 852ea0a9797..92d267ea1ec 100644
--- a/libs/partners/huggingface/pyproject.toml
+++ b/libs/partners/huggingface/pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
     "tokenizers>=0.19.1",
     "transformers>=4.39.0",
     "sentence-transformers>=2.6.0",
-    "huggingface-hub>=0.23.0",
+    "huggingface-hub>=0.30.2",
 ]
 name = "langchain-huggingface"
 version = "0.1.2"
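
A minimal usage sketch of the new ``provider`` field (illustrative only, not part of the diff; the model id and token handling are assumptions, with "sambanova" taken from the added docstring):

    import os

    from langchain_huggingface import HuggingFaceEndpointEmbeddings

    # Hypothetical model and provider choice; any provider listed in the
    # huggingface_hub inference docs would be passed the same way.
    embeddings = HuggingFaceEndpointEmbeddings(
        model="sentence-transformers/all-MiniLM-L6-v2",
        provider="sambanova",
        huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
    )

    # Both the sync and async paths now go through
    # InferenceClient.feature_extraction instead of the removed low-level
    # ``post`` call, so no manual JSON decoding is needed.
    vectors = embeddings.embed_documents(["hello world", "goodbye world"])
    query_vector = embeddings.embed_query("hello world")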