partners: (langchain-huggingface) Chat Models - Integrate Hugging Face Inference Providers and remove deprecated code (#30733)

Hi there, I'm Célina from 🤗,
This PR introduces support for Hugging Face's serverless Inference
Providers (documentation
[here](https://huggingface.co/docs/inference-providers/index)), allowing
users to specify different providers for chat completion and text
generation tasks.
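
For example (mirroring the updated docstring example further down in the diff), a provider can be selected at construction time; the model, provider and token below are placeholders:

```python
from langchain_huggingface import HuggingFaceEndpoint

# `provider` selects which Inference Provider serves the model.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-Nemo-Base-2407",
    provider="novita",
    max_new_tokens=100,
    do_sample=False,
    huggingfacehub_api_token="my-api-key",
)
print(llm.invoke("What is Deep Learning?"))
```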

This PR also removes the usage of `InferenceClient.post()` method in
`HuggingFaceEndpoint`, in favor of the task-specific `text_generation`
method. `InferenceClient.post()` is deprecated and will be removed in
`huggingface_hub v0.31.0`.
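
As a rough sketch of the migration (standalone `huggingface_hub` usage; model name and token are placeholders):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(
    model="mistralai/Mistral-Nemo-Base-2407",  # placeholder model
    api_key="my-api-key",                      # placeholder token
)

# Previously: client.post(json={"inputs": prompt, "parameters": params}, task=...)
# Now: the task-specific helper returns the generated text directly.
text = client.text_generation("What is Deep Learning?", max_new_tokens=100)
print(text)
```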

---
## Changes made
- bumped the minimum required version of the `huggingface-hub` package
to ensure compatibility with the latest API usage.
- added a `provider` field to `HuggingFaceEndpoint`, enabling users to
select the inference provider (e.g., 'cerebras', 'together',
'fireworks-ai'). Defaults to `hf-inference` (HF Inference API).
- replaced the deprecated `InferenceClient.post()` call in
`HuggingFaceEndpoint` with the task-specific `text_generation` method
for future-proofing; `post()` will be removed in `huggingface-hub`
v0.31.0.
- updated the `ChatHuggingFace` component (see the usage sketch after this list):
    - added async and streaming support.
    - added support for tool calling.
    - exposed underlying chat completion parameters for more granular control.
- added integration tests for `ChatHuggingFace` and updated the
corresponding unit tests.

  All changes are backward compatible.
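
A rough usage sketch of the updated `ChatHuggingFace` (model, provider, and tool schema below are illustrative placeholders, not part of this PR):

```python
from pydantic import BaseModel

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    provider="together",                         # placeholder provider
    huggingfacehub_api_token="my-api-key",       # placeholder token
)
chat = ChatHuggingFace(llm=llm)

# Streaming
for chunk in chat.stream("Tell me a joke about ducks"):
    print(chunk.content, end="", flush=True)

# Tool calling: bind a tool schema; the model may emit a structured tool call.
class GetWeather(BaseModel):
    """Get the current weather in a given city."""

    city: str

msg = chat.bind_tools([GetWeather]).invoke("What's the weather in Paris?")
print(msg.tool_calls)
```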

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
Commit 868f07f8f4 (parent 3072e4610a)
Authored by célina, committed via GitHub on 2025-04-29 15:53:14 +02:00
8 changed files with 699 additions and 504 deletions


@@ -1,5 +1,4 @@
import inspect
import json # type: ignore[import-not-found]
import logging
import os
from collections.abc import AsyncIterator, Iterator, Mapping
@@ -27,7 +26,7 @@ VALID_TASKS = (
class HuggingFaceEndpoint(LLM):
"""
HuggingFace Endpoint.
Hugging Face Endpoint. This works with any model that supports the text generation (i.e. text completion) task.
To use this class, you should have installed the ``huggingface_hub`` package, and
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
@@ -67,6 +66,15 @@ class HuggingFaceEndpoint(LLM):
)
print(llm.invoke("What is Deep Learning?"))
# Basic Example (no streaming) with Mistral-Nemo-Base-2407 model using a third-party provider (Novita).
llm = HuggingFaceEndpoint(
repo_id="mistralai/Mistral-Nemo-Base-2407",
provider="novita",
max_new_tokens=100,
do_sample=False,
huggingfacehub_api_token="my-api-key"
)
print(llm.invoke("What is Deep Learning?"))
""" # noqa: E501
endpoint_url: Optional[str] = None
@@ -74,6 +82,11 @@ class HuggingFaceEndpoint(LLM):
should be passed as an env variable in `HF_INFERENCE_ENDPOINT`"""
repo_id: Optional[str] = None
"""Repo to use. If endpoint_url is not specified then this needs to given"""
provider: Optional[str] = None
"""Name of the provider to use for inference with the model specified in `repo_id`.
e.g. "cerebras". if not specified, Defaults to "auto" i.e. the first of the
providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
huggingfacehub_api_token: Optional[str] = Field(
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
)
@@ -120,8 +133,7 @@ class HuggingFaceEndpoint(LLM):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""
"""Task to call the model with. Should be a task that returns `generated_text`."""
model_config = ConfigDict(
extra="forbid",
@@ -190,36 +202,22 @@ class HuggingFaceEndpoint(LLM):
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that package is installed and that the API token is valid."""
try:
from huggingface_hub import login # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)
if huggingfacehub_api_token is not None:
try:
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
from huggingface_hub import AsyncInferenceClient, InferenceClient
from huggingface_hub import ( # type: ignore[import]
AsyncInferenceClient, # type: ignore[import]
InferenceClient, # type: ignore[import]
)
# Instantiate clients with supported kwargs
sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
token=huggingfacehub_api_token,
api_key=huggingfacehub_api_token,
provider=self.provider, # type: ignore[arg-type]
**{
key: value
for key, value in self.server_kwargs.items()
@@ -231,14 +229,14 @@ class HuggingFaceEndpoint(LLM):
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
token=huggingfacehub_api_token,
api_key=huggingfacehub_api_token,
provider=self.provider, # type: ignore[arg-type]
**{
key: value
for key, value in self.server_kwargs.items()
if key in async_supported_kwargs
},
)
ignored_kwargs = (
set(self.server_kwargs.keys())
- sync_supported_kwargs
@@ -264,7 +262,7 @@ class HuggingFaceEndpoint(LLM):
"repetition_penalty": self.repetition_penalty,
"return_full_text": self.return_full_text,
"truncate": self.truncate,
"stop_sequences": self.stop_sequences,
"stop": self.stop_sequences,
"seed": self.seed,
"do_sample": self.do_sample,
"watermark": self.watermark,
@@ -276,7 +274,11 @@ class HuggingFaceEndpoint(LLM):
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint_url": self.endpoint_url, "task": self.task},
**{
"endpoint_url": self.endpoint_url,
"task": self.task,
"provider": self.provider,
},
**{"model_kwargs": _model_kwargs},
}
@@ -289,7 +291,7 @@ class HuggingFaceEndpoint(LLM):
self, runtime_stop: Optional[list[str]], **kwargs: Any
) -> dict[str, Any]:
params = {**self._default_params, **kwargs}
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
params["stop"] = params["stop"] + (runtime_stop or [])
return params
def _call(
@@ -307,19 +309,15 @@ class HuggingFaceEndpoint(LLM):
completion += chunk.text
return completion
else:
invocation_params["stop"] = invocation_params[
"stop_sequences"
] # porting 'stop_sequences' into the 'stop' argument
response = self.client.post(
json={"inputs": prompt, "parameters": invocation_params},
stream=False,
task=self.task,
response_text = self.client.text_generation(
prompt=prompt,
model=self.model,
**invocation_params,
)
response_text = json.loads(response.decode())[0]["generated_text"]
# Maybe the generation has stopped at one of the stop sequences:
# then we remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
@@ -340,17 +338,16 @@ class HuggingFaceEndpoint(LLM):
completion += chunk.text
return completion
else:
invocation_params["stop"] = invocation_params["stop_sequences"]
response = await self.async_client.post(
json={"inputs": prompt, "parameters": invocation_params},
response_text = await self.async_client.text_generation(
prompt=prompt,
**invocation_params,
model=self.model,
stream=False,
task=self.task,
)
response_text = json.loads(response.decode())[0]["generated_text"]
# Maybe the generation has stopped at one of the stop sequences:
# then remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
@@ -369,7 +366,7 @@ class HuggingFaceEndpoint(LLM):
):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
for stop_seq in invocation_params["stop"]:
if stop_seq in response:
stop_seq_found = stop_seq
@@ -405,7 +402,7 @@ class HuggingFaceEndpoint(LLM):
):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
for stop_seq in invocation_params["stop"]:
if stop_seq in response:
stop_seq_found = stop_seq