mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
partners: (langchain-huggingface) Embeddings - Integrate Inference Providers and remove deprecated code (#30735)
Hi there, this is a complementary PR to #30733. This PR introduces support for Hugging Face's serverless Inference Providers (documentation [here](https://huggingface.co/docs/inference-providers/index)), allowing users to specify different providers. This PR also removes the usage of the `InferenceClient.post()` method in `HuggingFaceEndpointEmbeddings`, in favor of the task-specific `feature_extraction` method. `InferenceClient.post()` is deprecated and will be removed in `huggingface_hub` v0.31.0. ## Changes made - Bumped the minimum required version of the `huggingface_hub` package to ensure compatibility with the latest API usage. - Added a `provider` field to `HuggingFaceEndpointEmbeddings`, enabling users to select the inference provider. - Replaced the deprecated `InferenceClient.post()` call in `HuggingFaceEndpointEmbeddings` with the task-specific `feature_extraction` method for future-proofing; `post()` will be removed in `huggingface-hub` v0.31.0. ✅ All changes are backward compatible. --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
98f0016fc2
commit
68361f9c2d
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@ -34,6 +33,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
async_client: Any = None #: :meta private:
|
||||
model: Optional[str] = None
|
||||
"""Model name to use."""
|
||||
provider: Optional[str] = None
|
||||
"""Name of the provider to use for inference with the model specified in
|
||||
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
|
||||
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
|
||||
repo_id: Optional[str] = None
|
||||
"""Huggingfacehub repository id, for backward compatibility."""
|
||||
task: Optional[str] = "feature-extraction"
|
||||
@ -74,11 +77,13 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
client = InferenceClient(
|
||||
model=self.model,
|
||||
token=huggingfacehub_api_token,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
async_client = AsyncInferenceClient(
|
||||
model=self.model,
|
||||
token=huggingfacehub_api_token,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
if self.task not in VALID_TASKS:
|
||||
@ -109,10 +114,8 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
# api doc: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/embed
|
||||
responses = self.client.post(
|
||||
json={"inputs": texts, **_model_kwargs}, task=self.task
|
||||
)
|
||||
return json.loads(responses.decode())
|
||||
responses = self.client.feature_extraction(text=texts, **_model_kwargs)
|
||||
return responses
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
|
||||
@ -126,10 +129,10 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
responses = await self.async_client.post(
|
||||
json={"inputs": texts, **_model_kwargs}, task=self.task
|
||||
responses = await self.async_client.feature_extraction(
|
||||
text=texts, **_model_kwargs
|
||||
)
|
||||
return json.loads(responses.decode())
|
||||
return responses
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
|
@ -11,7 +11,7 @@ dependencies = [
|
||||
"tokenizers>=0.19.1",
|
||||
"transformers>=4.39.0",
|
||||
"sentence-transformers>=2.6.0",
|
||||
"huggingface-hub>=0.23.0",
|
||||
"huggingface-hub>=0.30.2",
|
||||
]
|
||||
name = "langchain-huggingface"
|
||||
version = "0.1.2"
|
||||
|
Loading…
Reference in New Issue
Block a user