mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
partners: (langchain-huggingface) Chat Models - Integrate Hugging Face Inference Providers and remove deprecated code (#30733)
Hi there, I'm Célina from 🤗. This PR introduces support for Hugging Face's serverless Inference Providers (documentation [here](https://huggingface.co/docs/inference-providers/index)), allowing users to specify different providers for chat completion and text generation tasks.

This PR also removes the usage of the `InferenceClient.post()` method in `HuggingFaceEndpoint` in favor of the task-specific `text_generation` method. `InferenceClient.post()` is deprecated and will be removed in `huggingface_hub` v0.31.0.

---

## Changes made

- bumped the minimum required version of the `huggingface-hub` package to ensure compatibility with the latest API usage.
- added a `provider` field to `HuggingFaceEndpoint`, enabling users to select the inference provider (e.g., 'cerebras', 'together', 'fireworks-ai'). Defaults to `hf-inference` (HF Inference API).
- replaced the deprecated `InferenceClient.post()` call in `HuggingFaceEndpoint` with the task-specific `text_generation` method for future-proofing; `post()` will be removed in `huggingface_hub` v0.31.0.
- updated the `ChatHuggingFace` component:
  - added async and streaming support.
  - added support for tool calling.
  - exposed underlying chat completion parameters for more granular control.
- Added integration tests for `ChatHuggingFace` and updated the corresponding unit tests.

✅ All changes are backward compatible.

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
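As a quick illustration of the features described above, here is a minimal usage sketch (not taken from the PR itself): it reuses the Novita example values from the updated `HuggingFaceEndpoint` docstring and assumes the standard `ChatHuggingFace(llm=...)` wrapper; the model id, provider, and token are placeholders.

```python
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Route requests through a serverless Inference Provider via the new `provider` field.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-Nemo-Base-2407",  # placeholder model
    provider="novita",  # e.g. "cerebras", "together", "fireworks-ai"
    max_new_tokens=100,
    do_sample=False,
    huggingfacehub_api_token="my-api-key",  # placeholder token
)
print(llm.invoke("What is Deep Learning?"))

# ChatHuggingFace gains async and streaming support in this PR.
chat = ChatHuggingFace(llm=llm)
for chunk in chat.stream("What is Deep Learning?"):
    print(chunk.content, end="", flush=True)

# Per the PR description, tool calling is also supported,
# e.g. chat.bind_tools([...]) with your own tool definitions.
```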
@@ -1,5 +1,4 @@
import inspect
import json # type: ignore[import-not-found]
import logging
import os
from collections.abc import AsyncIterator, Iterator, Mapping
@@ -27,7 +26,7 @@ VALID_TASKS = (

class HuggingFaceEndpoint(LLM):
    """
    HuggingFace Endpoint.
    Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.

    To use this class, you should have installed the ``huggingface_hub`` package, and
    the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
@@ -67,6 +66,15 @@ class HuggingFaceEndpoint(LLM):
            )
            print(llm.invoke("What is Deep Learning?"))

            # Basic Example (no streaming) with Mistral-Nemo-Base-2407 model using a third-party provider (Novita).
            llm = HuggingFaceEndpoint(
                repo_id="mistralai/Mistral-Nemo-Base-2407",
                provider="novita",
                max_new_tokens=100,
                do_sample=False,
                huggingfacehub_api_token="my-api-key"
            )
            print(llm.invoke("What is Deep Learning?"))
    """ # noqa: E501

    endpoint_url: Optional[str] = None
@@ -74,6 +82,11 @@ class HuggingFaceEndpoint(LLM):
    should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
    repo_id: Optional[str] = None
    """Repo to use. If endpoint_url is not specified then this needs to given"""
    provider: Optional[str] = None
    """Name of the provider to use for inference with the model specified in `repo_id`.
    e.g. "cerebras". if not specified, Defaults to "auto" i.e. the first of the
    providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
    available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
    huggingfacehub_api_token: Optional[str] = Field(
        default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
    )
@@ -120,8 +133,7 @@ class HuggingFaceEndpoint(LLM):
    client: Any = None #: :meta private:
    async_client: Any = None #: :meta private:
    task: Optional[str] = None
    """Task to call the model with.
    Should be a task that returns `generated_text` or `summary_text`."""
    """Task to call the model with. Should be a task that returns `generated_text`."""

    model_config = ConfigDict(
        extra="forbid",
@@ -190,36 +202,22 @@ class HuggingFaceEndpoint(LLM):
    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that package is installed and that the API token is valid."""
        try:
            from huggingface_hub import login # type: ignore[import]

        except ImportError:
            raise ImportError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )

        huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
            "HF_TOKEN"
        )

        if huggingfacehub_api_token is not None:
            try:
                login(token=huggingfacehub_api_token)
            except Exception as e:
                raise ValueError(
                    "Could not authenticate with huggingface_hub. "
                    "Please check your API token."
                ) from e

        from huggingface_hub import AsyncInferenceClient, InferenceClient
        from huggingface_hub import ( # type: ignore[import]
            AsyncInferenceClient, # type: ignore[import]
            InferenceClient, # type: ignore[import]
        )

        # Instantiate clients with supported kwargs
        sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
        self.client = InferenceClient(
            model=self.model,
            timeout=self.timeout,
            token=huggingfacehub_api_token,
            api_key=huggingfacehub_api_token,
            provider=self.provider, # type: ignore[arg-type]
            **{
                key: value
                for key, value in self.server_kwargs.items()
@@ -231,14 +229,14 @@ class HuggingFaceEndpoint(LLM):
        self.async_client = AsyncInferenceClient(
            model=self.model,
            timeout=self.timeout,
            token=huggingfacehub_api_token,
            api_key=huggingfacehub_api_token,
            provider=self.provider, # type: ignore[arg-type]
            **{
                key: value
                for key, value in self.server_kwargs.items()
                if key in async_supported_kwargs
            },
        )

        ignored_kwargs = (
            set(self.server_kwargs.keys())
            - sync_supported_kwargs
@@ -264,7 +262,7 @@ class HuggingFaceEndpoint(LLM):
            "repetition_penalty": self.repetition_penalty,
            "return_full_text": self.return_full_text,
            "truncate": self.truncate,
            "stop_sequences": self.stop_sequences,
            "stop": self.stop_sequences,
            "seed": self.seed,
            "do_sample": self.do_sample,
            "watermark": self.watermark,
@@ -276,7 +274,11 @@ class HuggingFaceEndpoint(LLM):
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self.endpoint_url, "task": self.task},
            **{
                "endpoint_url": self.endpoint_url,
                "task": self.task,
                "provider": self.provider,
            },
            **{"model_kwargs": _model_kwargs},
        }

@@ -289,7 +291,7 @@ class HuggingFaceEndpoint(LLM):
        self, runtime_stop: Optional[list[str]], **kwargs: Any
    ) -> dict[str, Any]:
        params = {**self._default_params, **kwargs}
        params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
        params["stop"] = params["stop"] + (runtime_stop or [])
        return params

    def _call(
@@ -307,19 +309,15 @@ class HuggingFaceEndpoint(LLM):
                completion += chunk.text
            return completion
        else:
            invocation_params["stop"] = invocation_params[
                "stop_sequences"
            ] # porting 'stop_sequences' into the 'stop' argument
            response = self.client.post(
                json={"inputs": prompt, "parameters": invocation_params},
                stream=False,
                task=self.task,
            response_text = self.client.text_generation(
                prompt=prompt,
                model=self.model,
                **invocation_params,
            )
            response_text = json.loads(response.decode())[0]["generated_text"]

            # Maybe the generation has stopped at one of the stop sequences:
            # then we remove this stop sequence from the end of the generated text
            for stop_seq in invocation_params["stop_sequences"]:
            for stop_seq in invocation_params["stop"]:
                if response_text[-len(stop_seq) :] == stop_seq:
                    response_text = response_text[: -len(stop_seq)]
            return response_text
@@ -340,17 +338,16 @@ class HuggingFaceEndpoint(LLM):
                completion += chunk.text
            return completion
        else:
            invocation_params["stop"] = invocation_params["stop_sequences"]
            response = await self.async_client.post(
                json={"inputs": prompt, "parameters": invocation_params},
            response_text = await self.async_client.text_generation(
                prompt=prompt,
                **invocation_params,
                model=self.model,
                stream=False,
                task=self.task,
            )
            response_text = json.loads(response.decode())[0]["generated_text"]

            # Maybe the generation has stopped at one of the stop sequences:
            # then remove this stop sequence from the end of the generated text
            for stop_seq in invocation_params["stop_sequences"]:
            for stop_seq in invocation_params["stop"]:
                if response_text[-len(stop_seq) :] == stop_seq:
                    response_text = response_text[: -len(stop_seq)]
            return response_text
@@ -369,7 +366,7 @@ class HuggingFaceEndpoint(LLM):
        ):
            # identify stop sequence in generated text, if any
            stop_seq_found: Optional[str] = None
            for stop_seq in invocation_params["stop_sequences"]:
            for stop_seq in invocation_params["stop"]:
                if stop_seq in response:
                    stop_seq_found = stop_seq
@@ -405,7 +402,7 @@ class HuggingFaceEndpoint(LLM):
        ):
            # identify stop sequence in generated text, if any
            stop_seq_found: Optional[str] = None
            for stop_seq in invocation_params["stop_sequences"]:
            for stop_seq in invocation_params["stop"]:
                if stop_seq in response:
                    stop_seq_found = stop_seq
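For reference, the core migration performed in the hunks above can be sketched directly against `huggingface_hub` (model id, prompt, and parameter values below are placeholders, not taken from the PR): the deprecated generic `InferenceClient.post()` call gives way to the task-specific `text_generation()` method.

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="mistralai/Mistral-Nemo-Base-2407", provider="novita")

# Deprecated pattern removed by this PR (post() is scheduled for removal in huggingface_hub v0.31.0):
# response = client.post(json={"inputs": "What is Deep Learning?", "parameters": {...}}, task="text-generation")
# text = json.loads(response.decode())[0]["generated_text"]

# Task-specific replacement now used by HuggingFaceEndpoint:
text = client.text_generation(
    "What is Deep Learning?",
    max_new_tokens=100,
    do_sample=False,
    stop=["\n\n"],  # stop sequences are now passed as `stop` (formerly `stop_sequences`)
)
print(text)
```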