partners: (langchain-huggingface) Chat Models - Integrate Hugging Face Inference Providers and remove deprecated code (#30733)

Hi there, I'm Célina from 🤗,
This PR introduces support for Hugging Face's serverless Inference
Providers (documentation
[here](https://huggingface.co/docs/inference-providers/index)), allowing
users to specify different providers for chat completion and text
generation tasks.

This PR also removes the usage of `InferenceClient.post()` method in
`HuggingFaceEndpoint`, in favor of the task-specific `text_generation`
method. `InferenceClient.post()` is deprecated and will be removed in
`huggingface_hub v0.31.0`.

---
## Changes made
- bumped the minimum required version of the `huggingface-hub` package
to ensure compatibility with the latest API usage.
- added a `provider` field to `HuggingFaceEndpoint`, enabling users to
select the inference provider (e.g., 'cerebras', 'together',
'fireworks-ai'). Defaults to `hf-inference` (HF Inference API).
- replaced the deprecated `InferenceClient.post()` call in
`HuggingFaceEndpoint` with the task-specific `text_generation` method
for future-proofing, `post()` will be removed in huggingface-hub
v0.31.0.
- updated the `ChatHuggingFace` component:
    - added async and streaming support.
    - added support for tool calling.
- exposed underlying chat completion parameters for more granular
control.
- Added integration tests for `ChatHuggingFace` and updated the
corresponding unit tests.

  All changes are backward compatible.

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
célina
2025-04-29 15:53:14 +02:00
committed by GitHub
parent 3072e4610a
commit 868f07f8f4
8 changed files with 699 additions and 504 deletions

View File

@@ -1,11 +1,11 @@
from typing import Any # type: ignore[import-not-found]
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest # type: ignore[import-not-found]
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
@@ -13,92 +13,10 @@ from langchain_core.outputs import ChatResult
from langchain_core.tools import BaseTool
from langchain_huggingface.chat_models import ( # type: ignore[import]
TGI_MESSAGE,
ChatHuggingFace,
_convert_message_to_chat_message,
_convert_TGI_message_to_LC_message,
_convert_dict_to_message,
)
from langchain_huggingface.llms.huggingface_endpoint import (
HuggingFaceEndpoint,
)
@pytest.mark.parametrize(
("message", "expected"),
[
(
SystemMessage(content="Hello"),
dict(role="system", content="Hello"),
),
(
HumanMessage(content="Hello"),
dict(role="user", content="Hello"),
),
(
AIMessage(content="Hello"),
dict(role="assistant", content="Hello", tool_calls=None),
),
(
ChatMessage(role="assistant", content="Hello"),
dict(role="assistant", content="Hello"),
),
],
)
def test_convert_message_to_chat_message(
message: BaseMessage, expected: dict[str, str]
) -> None:
result = _convert_message_to_chat_message(message)
assert result == expected
@pytest.mark.parametrize(
("tgi_message", "expected"),
[
(
TGI_MESSAGE(role="assistant", content="Hello", tool_calls=[]),
AIMessage(content="Hello"),
),
(
TGI_MESSAGE(role="assistant", content="", tool_calls=[]),
AIMessage(content=""),
),
(
TGI_MESSAGE(
role="assistant",
content="",
tool_calls=[{"function": {"arguments": "function string"}}],
),
AIMessage(
content="",
additional_kwargs={
"tool_calls": [{"function": {"arguments": '"function string"'}}]
},
),
),
(
TGI_MESSAGE(
role="assistant",
content="",
tool_calls=[
{"function": {"arguments": {"answer": "function's string"}}}
],
),
AIMessage(
content="",
additional_kwargs={
"tool_calls": [
{"function": {"arguments": '{"answer": "function\'s string"}'}}
]
},
),
),
],
)
def test_convert_TGI_message_to_LC_message(
tgi_message: TGI_MESSAGE, expected: BaseMessage
) -> None:
result = _convert_TGI_message_to_LC_message(tgi_message)
assert result == expected
from langchain_huggingface.llms import HuggingFaceEndpoint
@pytest.fixture
@@ -118,16 +36,15 @@ def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
def test_create_chat_result(chat_hugging_face: Any) -> None:
mock_response = MagicMock()
mock_response.choices = [
MagicMock(
message=TGI_MESSAGE(
role="assistant", content="test message", tool_calls=[]
),
finish_reason="test finish reason",
)
]
mock_response.usage = {"tokens": 420}
mock_response = {
"choices": [
{
"message": {"role": "assistant", "content": "test message"},
"finish_reason": "test finish reason",
}
],
"usage": {"tokens": 420},
}
result = chat_hugging_face._create_chat_result(mock_response)
assert isinstance(result, ChatResult)
@@ -136,7 +53,7 @@ def test_create_chat_result(chat_hugging_face: Any) -> None:
result.generations[0].generation_info["finish_reason"] == "test finish reason" # type: ignore[index]
)
assert result.llm_output["token_usage"]["tokens"] == 420 # type: ignore[index]
assert result.llm_output["model"] == chat_hugging_face.llm.inference_server_url # type: ignore[index]
assert result.llm_output["model_name"] == chat_hugging_face.model_id # type: ignore[index]
@pytest.mark.parametrize(
@@ -207,6 +124,39 @@ def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
assert "Unknown message type:" in str(e.value)
@pytest.mark.parametrize(
("msg_dict", "expected_type", "expected_content"),
[
(
{"role": "system", "content": "You are helpful"},
SystemMessage,
"You are helpful",
),
(
{"role": "user", "content": "Hello there"},
HumanMessage,
"Hello there",
),
(
{"role": "assistant", "content": "How can I help?"},
AIMessage,
"How can I help?",
),
(
{"role": "function", "content": "result", "name": "get_time"},
FunctionMessage,
"result",
),
],
)
def test_convert_dict_to_message(
msg_dict: dict[str, Any], expected_type: type, expected_content: str
) -> None:
result = _convert_dict_to_message(msg_dict)
assert isinstance(result, expected_type)
assert result.content == expected_content
def tool_mock() -> dict:
return {"function": {"name": "test_tool"}}