Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-14 14:05:37 +00:00.
partners: (langchain-huggingface) Chat Models - Integrate Hugging Face Inference Providers and remove deprecated code (#30733)
Hi there, I'm Célina from 🤗. This PR introduces support for Hugging Face's serverless Inference Providers (documentation [here](https://huggingface.co/docs/inference-providers/index)), allowing users to specify different providers for chat completion and text generation tasks.

It also removes the usage of the `InferenceClient.post()` method in `HuggingFaceEndpoint` in favor of the task-specific `text_generation` method, since `InferenceClient.post()` is deprecated and will be removed in `huggingface_hub` v0.31.0.

---

## Changes made

- Bumped the minimum required version of the `huggingface-hub` package to ensure compatibility with the latest API usage.
- Added a `provider` field to `HuggingFaceEndpoint`, enabling users to select the inference provider (e.g. `cerebras`, `together`, `fireworks-ai`). Defaults to `hf-inference` (the HF Inference API).
- Replaced the deprecated `InferenceClient.post()` call in `HuggingFaceEndpoint` with the task-specific `text_generation` method for future-proofing.
- Updated the `ChatHuggingFace` component:
  - added async and streaming support,
  - added support for tool calling,
  - exposed underlying chat completion parameters for more granular control.
- Added integration tests for `ChatHuggingFace` and updated the corresponding unit tests.

✅ All changes are backward compatible.

Co-authored-by: ccurme <chester.curme@gmail.com>
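A minimal usage sketch of the new `provider` field, assuming the bumped package versions from this PR; the model ID and `fireworks-ai` provider are simply the values the integration tests below use, and a `HUGGINGFACEHUB_API_TOKEN` environment variable is assumed to be set:

```python
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Route requests through a serverless Inference Provider instead of the
# default "hf-inference" (values here mirror this PR's integration tests).
llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen2.5-72B-Instruct",
    task="conversational",
    provider="fireworks-ai",  # e.g. "cerebras", "together", "fireworks-ai"
    temperature=0,
)
chat = ChatHuggingFace(llm=llm)
print(chat.invoke("How do you say 'hello' in German?").content)
```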
`HuggingFacePipeline` streaming test:

```diff
@@ -6,7 +6,9 @@ from langchain_huggingface.llms import HuggingFacePipeline
 def test_huggingface_pipeline_streaming() -> None:
     """Test streaming tokens from huggingface_pipeline."""
     llm = HuggingFacePipeline.from_model_id(
-        model_id="gpt2", task="text-generation", pipeline_kwargs={"max_new_tokens": 10}
+        model_id="openai-community/gpt2",
+        task="text-generation",
+        pipeline_kwargs={"max_new_tokens": 10},
     )
     generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
     stream_results_string = ""
@@ -15,4 +17,4 @@ def test_huggingface_pipeline_streaming() -> None:
     for chunk in generator:
         assert isinstance(chunk, str)
         stream_results_string = chunk
-    assert len(stream_results_string.strip()) > 1
+    assert len(stream_results_string.strip()) > 0
```
`ChatHuggingFace` integration tests:

```diff
@@ -15,70 +15,39 @@ class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
 
     @property
     def chat_model_params(self) -> dict:
-        return {}
+        llm = HuggingFaceEndpoint(  # type: ignore[call-arg]
+            repo_id="Qwen/Qwen2.5-72B-Instruct",
+            task="conversational",
+            provider="fireworks-ai",
+            temperature=0,
+        )
+        return {"llm": llm}
 
     @pytest.fixture
     def model(self) -> BaseChatModel:
-        llm = HuggingFaceEndpoint(  # type: ignore[call-arg]
-            repo_id="HuggingFaceH4/zephyr-7b-beta",
-            task="text-generation",
-            max_new_tokens=512,
-            do_sample=False,
-            repetition_penalty=1.03,
-        )
-        return self.chat_model_class(llm=llm)  # type: ignore[call-arg]
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_stream(self, model: BaseChatModel) -> None:
-        super().test_stream(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    async def test_astream(self, model: BaseChatModel) -> None:
-        await super().test_astream(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_usage_metadata(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_stop_sequence(self, model: BaseChatModel) -> None:
-        super().test_stop_sequence(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_tool_calling(self, model: BaseChatModel) -> None:
-        super().test_tool_calling(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    async def test_tool_calling_async(self, model: BaseChatModel) -> None:
-        await super().test_tool_calling_async(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
-        super().test_tool_calling_with_no_arguments(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
-    def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
-        super().test_bind_runnables_as_tools(model)
-
-    @pytest.mark.xfail(reason=("Not implemented"))
+        return self.chat_model_class(**self.chat_model_params)  # type: ignore[call-arg]
+
+    @pytest.mark.xfail(
+        reason=("Overrding, testing only typed dict and json schema structured output")
+    )
+    @pytest.mark.parametrize("schema_type", ["typeddict", "json_schema"])
     def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
         super().test_structured_output(model, schema_type)
 
-    @pytest.mark.xfail(reason=("Not implemented"))
+    @pytest.mark.xfail(
+        reason=("Overrding, testing only typed dict and json schema structured output")
+    )
+    @pytest.mark.parametrize("schema_type", ["typeddict", "json_schema"])
     async def test_structured_output_async(
         self, model: BaseChatModel, schema_type: str
     ) -> None:  # type: ignore[override]
         super().test_structured_output(model, schema_type)
 
-    @pytest.mark.xfail(reason=("Not implemented"))
+    @pytest.mark.xfail(reason=("Pydantic structured output is not supported"))
     def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
         super().test_structured_output_pydantic_2_v1(model)
 
-    @pytest.mark.xfail(reason=("Not implemented"))
+    @pytest.mark.xfail(reason=("Pydantic structured output is not supported"))
     def test_structured_output_optional_param(self, model: BaseChatModel) -> None:
         super().test_structured_output_optional_param(model)
@@ -95,3 +64,7 @@ class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
         self, model: BaseChatModel, my_adder_tool: BaseTool
     ) -> None:
         super().test_structured_few_shot_examples(model, my_adder_tool=my_adder_tool)
+
+    @property
+    def has_tool_choice(self) -> bool:
+        return False
```
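Because the suite above now runs the shared streaming and tool-calling tests instead of marking them `xfail`, those features become available to users as well. A rough sketch under the same assumptions as the earlier example; the `add` tool is a made-up illustration, not something from this PR:

```python
from langchain_core.tools import tool

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen2.5-72B-Instruct",
    task="conversational",
    provider="fireworks-ai",
    temperature=0,
)
chat = ChatHuggingFace(llm=llm)

# Streaming: chunks arrive incrementally instead of as one final message.
for chunk in chat.stream("In one sentence, what is an inference provider?"):
    print(chunk.content, end="", flush=True)

# Tool calling: the model may answer with a structured call to `add`.
msg = chat.bind_tools([add]).invoke("Use the add tool to compute 2 + 3.")
print(msg.tool_calls)
```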
`ChatHuggingFace` unit tests:

```diff
@@ -1,11 +1,11 @@
-from typing import Any  # type: ignore[import-not-found]
+from typing import Any
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest  # type: ignore[import-not-found]
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
     ChatMessage,
+    FunctionMessage,
     HumanMessage,
     SystemMessage,
 )
@@ -13,92 +13,10 @@ from langchain_core.outputs import ChatResult
 from langchain_core.tools import BaseTool
 
-from langchain_huggingface.chat_models import (  # type: ignore[import]
-    TGI_MESSAGE,
+from langchain_huggingface.chat_models import (
     ChatHuggingFace,
-    _convert_message_to_chat_message,
-    _convert_TGI_message_to_LC_message,
+    _convert_dict_to_message,
 )
-from langchain_huggingface.llms.huggingface_endpoint import (
-    HuggingFaceEndpoint,
-)
-
-
-@pytest.mark.parametrize(
-    ("message", "expected"),
-    [
-        (
-            SystemMessage(content="Hello"),
-            dict(role="system", content="Hello"),
-        ),
-        (
-            HumanMessage(content="Hello"),
-            dict(role="user", content="Hello"),
-        ),
-        (
-            AIMessage(content="Hello"),
-            dict(role="assistant", content="Hello", tool_calls=None),
-        ),
-        (
-            ChatMessage(role="assistant", content="Hello"),
-            dict(role="assistant", content="Hello"),
-        ),
-    ],
-)
-def test_convert_message_to_chat_message(
-    message: BaseMessage, expected: dict[str, str]
-) -> None:
-    result = _convert_message_to_chat_message(message)
-    assert result == expected
-
-
-@pytest.mark.parametrize(
-    ("tgi_message", "expected"),
-    [
-        (
-            TGI_MESSAGE(role="assistant", content="Hello", tool_calls=[]),
-            AIMessage(content="Hello"),
-        ),
-        (
-            TGI_MESSAGE(role="assistant", content="", tool_calls=[]),
-            AIMessage(content=""),
-        ),
-        (
-            TGI_MESSAGE(
-                role="assistant",
-                content="",
-                tool_calls=[{"function": {"arguments": "function string"}}],
-            ),
-            AIMessage(
-                content="",
-                additional_kwargs={
-                    "tool_calls": [{"function": {"arguments": '"function string"'}}]
-                },
-            ),
-        ),
-        (
-            TGI_MESSAGE(
-                role="assistant",
-                content="",
-                tool_calls=[
-                    {"function": {"arguments": {"answer": "function's string"}}}
-                ],
-            ),
-            AIMessage(
-                content="",
-                additional_kwargs={
-                    "tool_calls": [
-                        {"function": {"arguments": '{"answer": "function\'s string"}'}}
-                    ]
-                },
-            ),
-        ),
-    ],
-)
-def test_convert_TGI_message_to_LC_message(
-    tgi_message: TGI_MESSAGE, expected: BaseMessage
-) -> None:
-    result = _convert_TGI_message_to_LC_message(tgi_message)
-    assert result == expected
+from langchain_huggingface.llms import HuggingFaceEndpoint
 
 
 @pytest.fixture
@@ -118,16 +36,15 @@ def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
 
 
 def test_create_chat_result(chat_hugging_face: Any) -> None:
-    mock_response = MagicMock()
-    mock_response.choices = [
-        MagicMock(
-            message=TGI_MESSAGE(
-                role="assistant", content="test message", tool_calls=[]
-            ),
-            finish_reason="test finish reason",
-        )
-    ]
-    mock_response.usage = {"tokens": 420}
+    mock_response = {
+        "choices": [
+            {
+                "message": {"role": "assistant", "content": "test message"},
+                "finish_reason": "test finish reason",
+            }
+        ],
+        "usage": {"tokens": 420},
+    }
 
     result = chat_hugging_face._create_chat_result(mock_response)
     assert isinstance(result, ChatResult)
@@ -136,7 +53,7 @@ def test_create_chat_result(chat_hugging_face: Any) -> None:
         result.generations[0].generation_info["finish_reason"] == "test finish reason"  # type: ignore[index]
     )
     assert result.llm_output["token_usage"]["tokens"] == 420  # type: ignore[index]
-    assert result.llm_output["model"] == chat_hugging_face.llm.inference_server_url  # type: ignore[index]
+    assert result.llm_output["model_name"] == chat_hugging_face.model_id  # type: ignore[index]
 
 
 @pytest.mark.parametrize(
@@ -207,6 +124,39 @@ def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
     assert "Unknown message type:" in str(e.value)
 
 
+@pytest.mark.parametrize(
+    ("msg_dict", "expected_type", "expected_content"),
+    [
+        (
+            {"role": "system", "content": "You are helpful"},
+            SystemMessage,
+            "You are helpful",
+        ),
+        (
+            {"role": "user", "content": "Hello there"},
+            HumanMessage,
+            "Hello there",
+        ),
+        (
+            {"role": "assistant", "content": "How can I help?"},
+            AIMessage,
+            "How can I help?",
+        ),
+        (
+            {"role": "function", "content": "result", "name": "get_time"},
+            FunctionMessage,
+            "result",
+        ),
+    ],
+)
+def test_convert_dict_to_message(
+    msg_dict: dict[str, Any], expected_type: type, expected_content: str
+) -> None:
+    result = _convert_dict_to_message(msg_dict)
+    assert isinstance(result, expected_type)
+    assert result.content == expected_content
+
+
 def tool_mock() -> dict:
     return {"function": {"name": "test_tool"}}
```