partners: (langchain-huggingface) Chat Models - Integrate Hugging Face Inference Providers and remove deprecated code (#30733)

Hi there, I'm Célina from 🤗,
This PR introduces support for Hugging Face's serverless Inference
Providers (documentation
[here](https://huggingface.co/docs/inference-providers/index)), allowing
users to specify different providers for chat completion and text
generation tasks.

This PR also removes the usage of `InferenceClient.post()` method in
`HuggingFaceEndpoint`, in favor of the task-specific `text_generation`
method. `InferenceClient.post()` is deprecated and will be removed in
`huggingface_hub v0.31.0`.

---
## Changes made
- bumped the minimum required version of the `huggingface-hub` package
to ensure compatibility with the latest API usage.
- added a `provider` field to `HuggingFaceEndpoint`, enabling users to
select the inference provider (e.g., 'cerebras', 'together',
'fireworks-ai'). Defaults to `hf-inference` (HF Inference API).
- replaced the deprecated `InferenceClient.post()` call in
`HuggingFaceEndpoint` with the task-specific `text_generation` method
for future-proofing, `post()` will be removed in huggingface-hub
v0.31.0.
- updated the `ChatHuggingFace` component:
    - added async and streaming support.
    - added support for tool calling.
- exposed underlying chat completion parameters for more granular
control.
- Added integration tests for `ChatHuggingFace` and updated the
corresponding unit tests.

  All changes are backward compatible.

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
célina 2025-04-29 15:53:14 +02:00 committed by GitHub
parent 3072e4610a
commit 868f07f8f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 699 additions and 504 deletions

View File

@ -2,14 +2,7 @@ from langchain_huggingface.chat_models.huggingface import ( # type: ignore[impo
TGI_MESSAGE, TGI_MESSAGE,
TGI_RESPONSE, TGI_RESPONSE,
ChatHuggingFace, ChatHuggingFace,
_convert_message_to_chat_message, _convert_dict_to_message,
_convert_TGI_message_to_LC_message,
) )
__all__ = [ __all__ = ["ChatHuggingFace", "_convert_dict_to_message", "TGI_MESSAGE", "TGI_RESPONSE"]
"ChatHuggingFace",
"_convert_message_to_chat_message",
"_convert_TGI_message_to_LC_message",
"TGI_MESSAGE",
"TGI_RESPONSE",
]

View File

@ -1,42 +1,65 @@
"""Hugging Face Chat Wrapper.""" """Hugging Face Chat Wrapper."""
import json import json
from collections.abc import Sequence from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from operator import itemgetter
Any, from typing import Any, Callable, Literal, Optional, Union, cast
Callable,
Literal,
Optional,
Union,
cast,
)
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk,
BaseMessage, BaseMessage,
BaseMessageChunk,
ChatMessage, ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage, SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage, ToolMessage,
ToolMessageChunk,
) )
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.messages.tool import ToolCallChunk
from langchain_core.runnables import Runnable from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import (
from pydantic import model_validator convert_to_json_schema,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self from typing_extensions import Self
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint from ..llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline from ..llms.huggingface_pipeline import HuggingFacePipeline
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
@dataclass @dataclass
@ -56,66 +79,143 @@ class TGI_MESSAGE:
tool_calls: list[dict] tool_calls: list[dict]
def _convert_message_to_chat_message( def _lc_tool_call_to_hf_tool_call(tool_call: ToolCall) -> dict:
message: BaseMessage, return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_hf_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict: ) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: dict[str, Any]
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content) message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage): elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content) message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs: message_dict = {"role": "assistant", "content": message.content}
tool_calls = [ if "function_call" in message.additional_kwargs:
{ message_dict["function_call"] = message.additional_kwargs["function_call"]
"function": { # If function call only, content is None not empty string
"name": tc["function"]["name"], if message_dict["content"] == "":
"arguments": tc["function"]["arguments"], message_dict["content"] = None
} if message.tool_calls or message.invalid_tool_calls:
} message_dict["tool_calls"] = [
for tc in message.additional_kwargs["tool_calls"] _lc_tool_call_to_hf_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_hf_tool_call(tc)
for tc in message.invalid_tool_calls
] ]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if "tool_calls" in message_dict and message_dict["content"] == "":
message_dict["content"] = None
else: else:
tool_calls = None pass
return {
"role": "assistant",
"content": message.content,
"tool_calls": tool_calls,
}
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
return dict(role="system", content=message.content) message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage): elif isinstance(message, FunctionMessage):
return { message_dict = {
"role": "tool", "role": "function",
"content": message.content, "content": message.content,
"name": message.name, "name": message.name,
} }
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else: else:
raise ValueError(f"Got unknown type {message}") raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def _convert_TGI_message_to_LC_message( def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
_message: TGI_MESSAGE, """Convert a dictionary to a LangChain message.
) -> BaseMessage:
role = _message.role Args:
assert role == "assistant", f"Expected role to be 'assistant', got {role}" _dict: The dictionary.
content = cast(str, _message.content)
if content is None: Returns:
content = "" The LangChain message.
additional_kwargs: dict = {} """
if tool_calls := _message.tool_calls: role = _dict.get("role")
if "arguments" in tool_calls[0]["function"]: if role == "user":
functions = tool_calls[0]["function"].pop("arguments") return HumanMessage(content=_dict.get("content", ""))
tool_calls[0]["function"]["arguments"] = json.dumps( elif role == "assistant":
functions, ensure_ascii=False content = _dict.get("content", "") or ""
) additional_kwargs: dict = {}
additional_kwargs["tool_calls"] = tool_calls if function_call := _dict.get("function_call"):
return AIMessage(content=content, additional_kwargs=additional_kwargs) additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name", "")
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id", ""),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role or "")
def _is_huggingface_hub(llm: Any) -> bool: def _is_huggingface_hub(llm: Any) -> bool:
try: try:
from langchain_community.llms.huggingface_hub import ( # type: ignore[import-not-found] from langchain_community.llms.huggingface_hub import (
HuggingFaceHub, HuggingFaceHub, # type: ignore[import-not-found]
) )
return isinstance(llm, HuggingFaceHub) return isinstance(llm, HuggingFaceHub)
@ -124,10 +224,69 @@ def _is_huggingface_hub(llm: Any) -> bool:
return False return False
def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
choice = chunk["choices"][0]
_dict = choice["delta"]
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: dict = {}
tool_call_chunks: list[ToolCallChunk] = []
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for rtc in raw_tool_calls:
try:
tool_call_chunks.append(
create_tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc.get("index"),
)
)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if usage := chunk.get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage_metadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
}
else:
usage_metadata = None
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore
def _is_huggingface_textgen_inference(llm: Any) -> bool: def _is_huggingface_textgen_inference(llm: Any) -> bool:
try: try:
from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found] from langchain_community.llms.huggingface_text_gen_inference import (
HuggingFaceTextGenInference, HuggingFaceTextGenInference, # type: ignore[import-not-found]
) )
return isinstance(llm, HuggingFaceTextGenInference) return isinstance(llm, HuggingFaceTextGenInference)
@ -172,11 +331,11 @@ class ChatHuggingFace(BaseChatModel):
'HuggingFacePipeline' LLM to be used. 'HuggingFacePipeline' LLM to be used.
Key init args client params: Key init args client params:
custom_get_token_ids: Optional[Callable[[str], List[int]]] custom_get_token_ids: Optional[Callable[[str], list[int]]]
Optional encoder to use for counting tokens. Optional encoder to use for counting tokens.
metadata: Optional[Dict[str, Any]] metadata: Optional[dict[str, Any]]
Metadata to add to the run trace. Metadata to add to the run trace.
tags: Optional[List[str]] tags: Optional[list[str]]
Tags to add to the run trace. Tags to add to the run trace.
tokenizer: Any tokenizer: Any
verbose: bool verbose: bool
@ -307,24 +466,43 @@ class ChatHuggingFace(BaseChatModel):
llm: Any llm: Any
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint, """LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint,
HuggingFaceHub, or HuggingFacePipeline.""" HuggingFaceHub, or HuggingFacePipeline."""
# TODO: Is system_message used anywhere?
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None tokenizer: Any = None
"""Tokenizer for the model. Only used for HuggingFacePipeline."""
model_id: Optional[str] = None model_id: Optional[str] = None
"""Model ID for the model. Only used for HuggingFaceEndpoint."""
temperature: Optional[float] = None
"""What sampling temperature to use."""
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens."""
frequency_penalty: Optional[float] = None
"""Penalizes repeated tokens according to frequency."""
seed: Optional[int] = None
"""Seed for generation"""
logprobs: Optional[bool] = None
"""Whether to return logprobs."""
top_logprobs: Optional[int] = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true
if this parameter is used."""
logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
n: Optional[int] = None
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
def __init__(self, **kwargs: Any): def __init__(self, **kwargs: Any):
super().__init__(**kwargs) super().__init__(**kwargs)
from transformers import AutoTokenizer # type: ignore[import]
self._resolve_model_id() self._resolve_model_id()
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
else self.tokenizer
)
@model_validator(mode="after") @model_validator(mode="after")
def validate_llm(self) -> Self: def validate_llm(self) -> Self:
if ( if (
@ -340,17 +518,30 @@ class ChatHuggingFace(BaseChatModel):
) )
return self return self
def _create_chat_result(self, response: TGI_RESPONSE) -> ChatResult: def _create_chat_result(self, response: dict) -> ChatResult:
generations = [] generations = []
finish_reason = response.choices[0].finish_reason token_usage = response.get("usage", {})
gen = ChatGeneration( for res in response["choices"]:
message=_convert_TGI_message_to_LC_message(response.choices[0].message), message = _convert_dict_to_message(res["message"])
generation_info={"finish_reason": finish_reason}, if token_usage and isinstance(message, AIMessage):
) message.usage_metadata = {
generations.append(gen) "input_tokens": token_usage.get("prompt_tokens", 0),
token_usage = response.usage "output_tokens": token_usage.get("completion_tokens", 0),
model_object = self.llm.inference_server_url "total_tokens": token_usage.get("total_tokens", 0),
llm_output = {"token_usage": token_usage, "model": model_object} }
generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(
message=message,
generation_info=generation_info,
)
generations.append(gen)
llm_output = {
"token_usage": token_usage,
"model_name": self.model_id,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output) return ChatResult(generations=generations, llm_output=llm_output)
def _generate( def _generate(
@ -358,18 +549,38 @@ class ChatHuggingFace(BaseChatModel):
messages: list[BaseMessage], messages: list[BaseMessage],
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if _is_huggingface_textgen_inference(self.llm): if _is_huggingface_textgen_inference(self.llm):
message_dicts = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
answer = self.llm.client.chat(messages=message_dicts, **kwargs) answer = self.llm.client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer) return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm): elif _is_huggingface_endpoint(self.llm):
message_dicts = self._create_message_dicts(messages, stop) if should_stream:
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs) stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
"stop": stop,
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
answer = self.llm.client.chat_completion(messages=message_dicts, **params)
return self._create_chat_result(answer) return self._create_chat_result(answer)
else: else:
llm_input = self._to_chat_prompt(messages) llm_input = self._to_chat_prompt(messages)
if should_stream:
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
llm_result = self.llm._generate( llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
) )
@ -380,12 +591,36 @@ class ChatHuggingFace(BaseChatModel):
messages: list[BaseMessage], messages: list[BaseMessage],
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
if _is_huggingface_textgen_inference(self.llm): if _is_huggingface_textgen_inference(self.llm):
message_dicts = self._create_message_dicts(messages, stop) message_dicts, params = self._create_message_dicts(messages, stop)
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs) answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer) return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
answer = await self.llm.async_client.chat_completion(
messages=message_dicts, **params
)
return self._create_chat_result(answer)
elif _is_huggingface_pipeline(self.llm):
raise NotImplementedError(
"async generation is not supported with HuggingFacePipeline"
)
else: else:
llm_input = self._to_chat_prompt(messages) llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate( llm_result = await self.llm._agenerate(
@ -393,6 +628,93 @@ class ChatHuggingFace(BaseChatModel):
) )
return self._to_chat_result(llm_result) return self._to_chat_result(llm_result)
def _stream(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if _is_huggingface_endpoint(self.llm):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.llm.client.chat_completion(
messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
message_chunk = _convert_chunk_to_message_chunk(
chunk, default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
generation_info["model_name"] = self.model_id
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
yield generation_chunk
else:
llm_input = self._to_chat_prompt(messages)
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
for chunk in stream_iter: # chunk is a GenerationChunk
chat_chunk = ChatGenerationChunk(
message=AIMessageChunk(content=chunk.text),
generation_info=chunk.generation_info,
)
yield chat_chunk
async def _astream(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await self.llm.async_client.chat_completion(
messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
generation_info["model_name"] = self.model_id
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=generation_chunk.text,
chunk=generation_chunk,
logprobs=logprobs,
)
yield generation_chunk
def _to_chat_prompt( def _to_chat_prompt(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -451,8 +773,18 @@ class ChatHuggingFace(BaseChatModel):
elif _is_huggingface_textgen_inference(self.llm): elif _is_huggingface_textgen_inference(self.llm):
endpoint_url: Optional[str] = self.llm.inference_server_url endpoint_url: Optional[str] = self.llm.inference_server_url
elif _is_huggingface_pipeline(self.llm): elif _is_huggingface_pipeline(self.llm):
from transformers import AutoTokenizer # type: ignore[import]
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
else self.tokenizer
)
self.model_id = self.llm.model_id self.model_id = self.llm.model_id
return return
elif _is_huggingface_endpoint(self.llm):
self.model_id = self.llm.repo_id or self.llm.model
return
else: else:
endpoint_url = self.llm.endpoint_url endpoint_url = self.llm.endpoint_url
available_endpoints = list_inference_endpoints("*") available_endpoints = list_inference_endpoints("*")
@ -525,11 +857,153 @@ class ChatHuggingFace(BaseChatModel):
kwargs["tool_choice"] = tool_choice kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs) return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Optional[Union[dict, type[BaseModel]]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a typedDict class (support added in 0.1.7),
Pydantic class is currently supported.
method: The method for steering model generation, one of:
- "function_calling": uses tool-calling features.
- "json_schema": uses dedicated structured output features.
- "json_mode": uses JSON mode.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
an instance of ``schema`` (i.e., a Pydantic object).
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
If ``include_raw`` is True, then Runnable outputs a dict with keys:
- ``"raw"``: BaseMessage
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException]
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
formatted_tool = convert_to_openai_tool(schema)
tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
ls_structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": formatted_tool,
},
)
if is_pydantic_schema:
raise NotImplementedError(
"Pydantic schema is not supported for function calling"
)
else:
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
)
elif method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is 'json_schema'. "
"Received None."
)
formatted_schema = convert_to_json_schema(schema)
llm = self.bind(
response_format={"type": "json_object", "schema": formatted_schema},
ls_structured_output_format={
"kwargs": {"method": "json_schema"},
"schema": schema,
},
)
output_parser: Union[ # type: ignore[no-redef]
JsonOutputKeyToolsParser, JsonOutputParser
] = JsonOutputParser() # type: ignore[arg-type]
elif method == "json_mode":
llm = self.bind(
response_format={"type": "json_object"},
ls_structured_output_format={
"kwargs": {"method": "json_mode"},
"schema": schema,
},
)
output_parser: Union[ # type: ignore[no-redef]
JsonOutputKeyToolsParser, JsonOutputParser
] = JsonOutputParser() # type: ignore[arg-type]
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _create_message_dicts( def _create_message_dicts(
self, messages: list[BaseMessage], stop: Optional[list[str]] self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> list[dict[Any, Any]]: ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
message_dicts = [_convert_message_to_chat_message(m) for m in messages] params = self._default_params
return message_dicts if stop is not None:
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Hugging Face
Inference Providers API."""
params = {
"model": self.model_id,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
"stop": self.stop,
**(self.model_kwargs if self.model_kwargs else {}),
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
return params
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:

View File

@ -1,5 +1,4 @@
import inspect import inspect
import json # type: ignore[import-not-found]
import logging import logging
import os import os
from collections.abc import AsyncIterator, Iterator, Mapping from collections.abc import AsyncIterator, Iterator, Mapping
@ -27,7 +26,7 @@ VALID_TASKS = (
class HuggingFaceEndpoint(LLM): class HuggingFaceEndpoint(LLM):
""" """
HuggingFace Endpoint. Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
To use this class, you should have installed the ``huggingface_hub`` package, and To use this class, you should have installed the ``huggingface_hub`` package, and
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
@ -67,6 +66,15 @@ class HuggingFaceEndpoint(LLM):
) )
print(llm.invoke("What is Deep Learning?")) print(llm.invoke("What is Deep Learning?"))
# Basic Example (no streaming) with Mistral-Nemo-Base-2407 model using a third-party provider (Novita).
llm = HuggingFaceEndpoint(
repo_id="mistralai/Mistral-Nemo-Base-2407",
provider="novita",
max_new_tokens=100,
do_sample=False,
huggingfacehub_api_token="my-api-key"
)
print(llm.invoke("What is Deep Learning?"))
""" # noqa: E501 """ # noqa: E501
endpoint_url: Optional[str] = None endpoint_url: Optional[str] = None
@ -74,6 +82,11 @@ class HuggingFaceEndpoint(LLM):
should be pass as env variable in `HF_INFERENCE_ENDPOINT`""" should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
repo_id: Optional[str] = None repo_id: Optional[str] = None
"""Repo to use. If endpoint_url is not specified then this needs to given""" """Repo to use. If endpoint_url is not specified then this needs to given"""
provider: Optional[str] = None
"""Name of the provider to use for inference with the model specified in `repo_id`.
e.g. "cerebras". if not specified, Defaults to "auto" i.e. the first of the
providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
huggingfacehub_api_token: Optional[str] = Field( huggingfacehub_api_token: Optional[str] = Field(
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None) default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
) )
@ -120,8 +133,7 @@ class HuggingFaceEndpoint(LLM):
client: Any = None #: :meta private: client: Any = None #: :meta private:
async_client: Any = None #: :meta private: async_client: Any = None #: :meta private:
task: Optional[str] = None task: Optional[str] = None
"""Task to call the model with. """Task to call the model with. Should be a task that returns `generated_text`."""
Should be a task that returns `generated_text` or `summary_text`."""
model_config = ConfigDict( model_config = ConfigDict(
extra="forbid", extra="forbid",
@ -190,36 +202,22 @@ class HuggingFaceEndpoint(LLM):
@model_validator(mode="after") @model_validator(mode="after")
def validate_environment(self) -> Self: def validate_environment(self) -> Self:
"""Validate that package is installed and that the API token is valid.""" """Validate that package is installed and that the API token is valid."""
try:
from huggingface_hub import login # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv( huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN" "HF_TOKEN"
) )
if huggingfacehub_api_token is not None: from huggingface_hub import ( # type: ignore[import]
try: AsyncInferenceClient, # type: ignore[import]
login(token=huggingfacehub_api_token) InferenceClient, # type: ignore[import]
except Exception as e: )
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e
from huggingface_hub import AsyncInferenceClient, InferenceClient
# Instantiate clients with supported kwargs # Instantiate clients with supported kwargs
sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters) sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
self.client = InferenceClient( self.client = InferenceClient(
model=self.model, model=self.model,
timeout=self.timeout, timeout=self.timeout,
token=huggingfacehub_api_token, api_key=huggingfacehub_api_token,
provider=self.provider, # type: ignore[arg-type]
**{ **{
key: value key: value
for key, value in self.server_kwargs.items() for key, value in self.server_kwargs.items()
@ -231,14 +229,14 @@ class HuggingFaceEndpoint(LLM):
self.async_client = AsyncInferenceClient( self.async_client = AsyncInferenceClient(
model=self.model, model=self.model,
timeout=self.timeout, timeout=self.timeout,
token=huggingfacehub_api_token, api_key=huggingfacehub_api_token,
provider=self.provider, # type: ignore[arg-type]
**{ **{
key: value key: value
for key, value in self.server_kwargs.items() for key, value in self.server_kwargs.items()
if key in async_supported_kwargs if key in async_supported_kwargs
}, },
) )
ignored_kwargs = ( ignored_kwargs = (
set(self.server_kwargs.keys()) set(self.server_kwargs.keys())
- sync_supported_kwargs - sync_supported_kwargs
@ -264,7 +262,7 @@ class HuggingFaceEndpoint(LLM):
"repetition_penalty": self.repetition_penalty, "repetition_penalty": self.repetition_penalty,
"return_full_text": self.return_full_text, "return_full_text": self.return_full_text,
"truncate": self.truncate, "truncate": self.truncate,
"stop_sequences": self.stop_sequences, "stop": self.stop_sequences,
"seed": self.seed, "seed": self.seed,
"do_sample": self.do_sample, "do_sample": self.do_sample,
"watermark": self.watermark, "watermark": self.watermark,
@ -276,7 +274,11 @@ class HuggingFaceEndpoint(LLM):
"""Get the identifying parameters.""" """Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
return { return {
**{"endpoint_url": self.endpoint_url, "task": self.task}, **{
"endpoint_url": self.endpoint_url,
"task": self.task,
"provider": self.provider,
},
**{"model_kwargs": _model_kwargs}, **{"model_kwargs": _model_kwargs},
} }
@ -289,7 +291,7 @@ class HuggingFaceEndpoint(LLM):
self, runtime_stop: Optional[list[str]], **kwargs: Any self, runtime_stop: Optional[list[str]], **kwargs: Any
) -> dict[str, Any]: ) -> dict[str, Any]:
params = {**self._default_params, **kwargs} params = {**self._default_params, **kwargs}
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or []) params["stop"] = params["stop"] + (runtime_stop or [])
return params return params
def _call( def _call(
@ -307,19 +309,15 @@ class HuggingFaceEndpoint(LLM):
completion += chunk.text completion += chunk.text
return completion return completion
else: else:
invocation_params["stop"] = invocation_params[ response_text = self.client.text_generation(
"stop_sequences" prompt=prompt,
] # porting 'stop_sequences' into the 'stop' argument model=self.model,
response = self.client.post( **invocation_params,
json={"inputs": prompt, "parameters": invocation_params},
stream=False,
task=self.task,
) )
response_text = json.loads(response.decode())[0]["generated_text"]
# Maybe the generation has stopped at one of the stop sequences: # Maybe the generation has stopped at one of the stop sequences:
# then we remove this stop sequence from the end of the generated text # then we remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq: if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)] response_text = response_text[: -len(stop_seq)]
return response_text return response_text
@ -340,17 +338,16 @@ class HuggingFaceEndpoint(LLM):
completion += chunk.text completion += chunk.text
return completion return completion
else: else:
invocation_params["stop"] = invocation_params["stop_sequences"] response_text = await self.async_client.text_generation(
response = await self.async_client.post( prompt=prompt,
json={"inputs": prompt, "parameters": invocation_params}, **invocation_params,
model=self.model,
stream=False, stream=False,
task=self.task,
) )
response_text = json.loads(response.decode())[0]["generated_text"]
# Maybe the generation has stopped at one of the stop sequences: # Maybe the generation has stopped at one of the stop sequences:
# then remove this stop sequence from the end of the generated text # then remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq: if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)] response_text = response_text[: -len(stop_seq)]
return response_text return response_text
@ -369,7 +366,7 @@ class HuggingFaceEndpoint(LLM):
): ):
# identify stop sequence in generated text, if any # identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop"]:
if stop_seq in response: if stop_seq in response:
stop_seq_found = stop_seq stop_seq_found = stop_seq
@ -405,7 +402,7 @@ class HuggingFaceEndpoint(LLM):
): ):
# identify stop sequence in generated text, if any # identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]: for stop_seq in invocation_params["stop"]:
if stop_seq in response: if stop_seq in response:
stop_seq_found = stop_seq stop_seq_found = stop_seq

View File

@ -44,7 +44,6 @@ typing = ["mypy<2.0,>=1.10", "langchain-core"]
[tool.uv.sources] [tool.uv.sources]
langchain-core = { path = "../../core", editable = true } langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true } langchain-tests = { path = "../../standard-tests", editable = true }
langchain-community = { path = "../../community", editable = true }
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"

View File

@ -6,7 +6,9 @@ from langchain_huggingface.llms import HuggingFacePipeline
def test_huggingface_pipeline_streaming() -> None: def test_huggingface_pipeline_streaming() -> None:
"""Test streaming tokens from huggingface_pipeline.""" """Test streaming tokens from huggingface_pipeline."""
llm = HuggingFacePipeline.from_model_id( llm = HuggingFacePipeline.from_model_id(
model_id="gpt2", task="text-generation", pipeline_kwargs={"max_new_tokens": 10} model_id="openai-community/gpt2",
task="text-generation",
pipeline_kwargs={"max_new_tokens": 10},
) )
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."]) generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
stream_results_string = "" stream_results_string = ""
@ -15,4 +17,4 @@ def test_huggingface_pipeline_streaming() -> None:
for chunk in generator: for chunk in generator:
assert isinstance(chunk, str) assert isinstance(chunk, str)
stream_results_string = chunk stream_results_string = chunk
assert len(stream_results_string.strip()) > 1 assert len(stream_results_string.strip()) > 0

View File

@ -15,70 +15,39 @@ class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
@property @property
def chat_model_params(self) -> dict: def chat_model_params(self) -> dict:
return {} llm = HuggingFaceEndpoint( # type: ignore[call-arg]
repo_id="Qwen/Qwen2.5-72B-Instruct",
task="conversational",
provider="fireworks-ai",
temperature=0,
)
return {"llm": llm}
@pytest.fixture @pytest.fixture
def model(self) -> BaseChatModel: def model(self) -> BaseChatModel:
llm = HuggingFaceEndpoint( # type: ignore[call-arg] return self.chat_model_class(**self.chat_model_params) # type: ignore[call-arg]
repo_id="HuggingFaceH4/zephyr-7b-beta",
task="text-generation",
max_new_tokens=512,
do_sample=False,
repetition_penalty=1.03,
)
return self.chat_model_class(llm=llm) # type: ignore[call-arg]
@pytest.mark.xfail(reason=("Not implemented")) @pytest.mark.xfail(
def test_stream(self, model: BaseChatModel) -> None: reason=("Overrding, testing only typed dict and json schema structured output")
super().test_stream(model) )
@pytest.mark.parametrize("schema_type", ["typeddict", "json_schema"])
@pytest.mark.xfail(reason=("Not implemented"))
async def test_astream(self, model: BaseChatModel) -> None:
await super().test_astream(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_usage_metadata(self, model: BaseChatModel) -> None:
super().test_usage_metadata(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
super().test_usage_metadata_streaming(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_stop_sequence(self, model: BaseChatModel) -> None:
super().test_stop_sequence(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_tool_calling(self, model: BaseChatModel) -> None:
super().test_tool_calling(model)
@pytest.mark.xfail(reason=("Not implemented"))
async def test_tool_calling_async(self, model: BaseChatModel) -> None:
await super().test_tool_calling_async(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None:
super().test_tool_calling_with_no_arguments(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
super().test_bind_runnables_as_tools(model)
@pytest.mark.xfail(reason=("Not implemented"))
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None: def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
super().test_structured_output(model, schema_type) super().test_structured_output(model, schema_type)
@pytest.mark.xfail(reason=("Not implemented")) @pytest.mark.xfail(
reason=("Overrding, testing only typed dict and json schema structured output")
)
@pytest.mark.parametrize("schema_type", ["typeddict", "json_schema"])
async def test_structured_output_async( async def test_structured_output_async(
self, model: BaseChatModel, schema_type: str self, model: BaseChatModel, schema_type: str
) -> None: # type: ignore[override] ) -> None: # type: ignore[override]
super().test_structured_output(model, schema_type) super().test_structured_output(model, schema_type)
@pytest.mark.xfail(reason=("Not implemented")) @pytest.mark.xfail(reason=("Pydantic structured output is not supported"))
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
super().test_structured_output_pydantic_2_v1(model) super().test_structured_output_pydantic_2_v1(model)
@pytest.mark.xfail(reason=("Not implemented")) @pytest.mark.xfail(reason=("Pydantic structured output is not supported"))
def test_structured_output_optional_param(self, model: BaseChatModel) -> None: def test_structured_output_optional_param(self, model: BaseChatModel) -> None:
super().test_structured_output_optional_param(model) super().test_structured_output_optional_param(model)
@ -95,3 +64,7 @@ class TestHuggingFaceEndpoint(ChatModelIntegrationTests):
self, model: BaseChatModel, my_adder_tool: BaseTool self, model: BaseChatModel, my_adder_tool: BaseTool
) -> None: ) -> None:
super().test_structured_few_shot_examples(model, my_adder_tool=my_adder_tool) super().test_structured_few_shot_examples(model, my_adder_tool=my_adder_tool)
@property
def has_tool_choice(self) -> bool:
return False

View File

@ -1,11 +1,11 @@
from typing import Any # type: ignore[import-not-found] from typing import Any
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
ChatMessage, FunctionMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
@ -13,92 +13,10 @@ from langchain_core.outputs import ChatResult
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_huggingface.chat_models import ( # type: ignore[import] from langchain_huggingface.chat_models import ( # type: ignore[import]
TGI_MESSAGE,
ChatHuggingFace, ChatHuggingFace,
_convert_message_to_chat_message, _convert_dict_to_message,
_convert_TGI_message_to_LC_message,
) )
from langchain_huggingface.llms.huggingface_endpoint import ( from langchain_huggingface.llms import HuggingFaceEndpoint
HuggingFaceEndpoint,
)
@pytest.mark.parametrize(
("message", "expected"),
[
(
SystemMessage(content="Hello"),
dict(role="system", content="Hello"),
),
(
HumanMessage(content="Hello"),
dict(role="user", content="Hello"),
),
(
AIMessage(content="Hello"),
dict(role="assistant", content="Hello", tool_calls=None),
),
(
ChatMessage(role="assistant", content="Hello"),
dict(role="assistant", content="Hello"),
),
],
)
def test_convert_message_to_chat_message(
message: BaseMessage, expected: dict[str, str]
) -> None:
result = _convert_message_to_chat_message(message)
assert result == expected
@pytest.mark.parametrize(
("tgi_message", "expected"),
[
(
TGI_MESSAGE(role="assistant", content="Hello", tool_calls=[]),
AIMessage(content="Hello"),
),
(
TGI_MESSAGE(role="assistant", content="", tool_calls=[]),
AIMessage(content=""),
),
(
TGI_MESSAGE(
role="assistant",
content="",
tool_calls=[{"function": {"arguments": "function string"}}],
),
AIMessage(
content="",
additional_kwargs={
"tool_calls": [{"function": {"arguments": '"function string"'}}]
},
),
),
(
TGI_MESSAGE(
role="assistant",
content="",
tool_calls=[
{"function": {"arguments": {"answer": "function's string"}}}
],
),
AIMessage(
content="",
additional_kwargs={
"tool_calls": [
{"function": {"arguments": '{"answer": "function\'s string"}'}}
]
},
),
),
],
)
def test_convert_TGI_message_to_LC_message(
tgi_message: TGI_MESSAGE, expected: BaseMessage
) -> None:
result = _convert_TGI_message_to_LC_message(tgi_message)
assert result == expected
@pytest.fixture @pytest.fixture
@ -118,16 +36,15 @@ def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
def test_create_chat_result(chat_hugging_face: Any) -> None: def test_create_chat_result(chat_hugging_face: Any) -> None:
mock_response = MagicMock() mock_response = {
mock_response.choices = [ "choices": [
MagicMock( {
message=TGI_MESSAGE( "message": {"role": "assistant", "content": "test message"},
role="assistant", content="test message", tool_calls=[] "finish_reason": "test finish reason",
), }
finish_reason="test finish reason", ],
) "usage": {"tokens": 420},
] }
mock_response.usage = {"tokens": 420}
result = chat_hugging_face._create_chat_result(mock_response) result = chat_hugging_face._create_chat_result(mock_response)
assert isinstance(result, ChatResult) assert isinstance(result, ChatResult)
@ -136,7 +53,7 @@ def test_create_chat_result(chat_hugging_face: Any) -> None:
result.generations[0].generation_info["finish_reason"] == "test finish reason" # type: ignore[index] result.generations[0].generation_info["finish_reason"] == "test finish reason" # type: ignore[index]
) )
assert result.llm_output["token_usage"]["tokens"] == 420 # type: ignore[index] assert result.llm_output["token_usage"]["tokens"] == 420 # type: ignore[index]
assert result.llm_output["model"] == chat_hugging_face.llm.inference_server_url # type: ignore[index] assert result.llm_output["model_name"] == chat_hugging_face.model_id # type: ignore[index]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -207,6 +124,39 @@ def test_to_chatml_format_with_invalid_type(chat_hugging_face: Any) -> None:
assert "Unknown message type:" in str(e.value) assert "Unknown message type:" in str(e.value)
@pytest.mark.parametrize(
("msg_dict", "expected_type", "expected_content"),
[
(
{"role": "system", "content": "You are helpful"},
SystemMessage,
"You are helpful",
),
(
{"role": "user", "content": "Hello there"},
HumanMessage,
"Hello there",
),
(
{"role": "assistant", "content": "How can I help?"},
AIMessage,
"How can I help?",
),
(
{"role": "function", "content": "result", "name": "get_time"},
FunctionMessage,
"result",
),
],
)
def test_convert_dict_to_message(
msg_dict: dict[str, Any], expected_type: type, expected_content: str
) -> None:
result = _convert_dict_to_message(msg_dict)
assert isinstance(result, expected_type)
assert result.content == expected_content
def tool_mock() -> dict: def tool_mock() -> dict:
return {"function": {"name": "test_tool"}} return {"function": {"name": "test_tool"}}

View File

@ -1,5 +1,4 @@
version = 1 version = 1
revision = 1
requires-python = ">=3.9" requires-python = ">=3.9"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13'", "python_full_version >= '3.13'",
@ -857,7 +856,7 @@ wheels = [
[[package]] [[package]]
name = "langchain" name = "langchain"
version = "0.3.24" version = "0.3.24"
source = { editable = "../../langchain" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "async-timeout", marker = "python_full_version < '3.11'" }, { name = "async-timeout", marker = "python_full_version < '3.11'" },
{ name = "langchain-core" }, { name = "langchain-core" },
@ -868,108 +867,15 @@ dependencies = [
{ name = "requests" }, { name = "requests" },
{ name = "sqlalchemy" }, { name = "sqlalchemy" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/a3/8f/db961066a65e678036886c73234827c56547fed2e06fd1b425767e4dc059/langchain-0.3.24.tar.gz", hash = "sha256:caf1bacdabbea429bc79b58b118c06c3386107d92812e15922072b91745f070f", size = 10224882 }
[package.metadata] wheels = [
requires-dist = [ { url = "https://files.pythonhosted.org/packages/ba/83/77392f0a6a560e471075b125656b392d3b889be65ee8e93a5c31aa7a62bb/langchain-0.3.24-py3-none-any.whl", hash = "sha256:596c5444716644ddd0cd819fb2bc9d0fd4221503b219fdfb5016edcfaa7da8ef", size = 1010778 },
{ name = "async-timeout", marker = "python_full_version < '3.11'", specifier = ">=4.0.0,<5.0.0" },
{ name = "langchain-anthropic", marker = "extra == 'anthropic'" },
{ name = "langchain-aws", marker = "extra == 'aws'" },
{ name = "langchain-azure-ai", marker = "extra == 'azure-ai'" },
{ name = "langchain-cohere", marker = "extra == 'cohere'" },
{ name = "langchain-community", marker = "extra == 'community'" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-deepseek", marker = "extra == 'deepseek'" },
{ name = "langchain-fireworks", marker = "extra == 'fireworks'" },
{ name = "langchain-google-genai", marker = "extra == 'google-genai'" },
{ name = "langchain-google-vertexai", marker = "extra == 'google-vertexai'" },
{ name = "langchain-groq", marker = "extra == 'groq'" },
{ name = "langchain-huggingface", marker = "extra == 'huggingface'" },
{ name = "langchain-mistralai", marker = "extra == 'mistralai'" },
{ name = "langchain-ollama", marker = "extra == 'ollama'" },
{ name = "langchain-openai", marker = "extra == 'openai'", editable = "../openai" },
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "langchain-together", marker = "extra == 'together'" },
{ name = "langchain-xai", marker = "extra == 'xai'" },
{ name = "langsmith", specifier = ">=0.1.17,<0.4" },
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
{ name = "pyyaml", specifier = ">=5.3" },
{ name = "requests", specifier = ">=2,<3" },
{ name = "sqlalchemy", specifier = ">=1.4,<3" },
]
provides-extras = ["community", "anthropic", "openai", "azure-ai", "cohere", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
[package.metadata.requires-dev]
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
dev = [
{ name = "jupyter", specifier = ">=1.0.0,<2.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "playwright", specifier = ">=1.28.0,<2.0.0" },
{ name = "setuptools", specifier = ">=67.6.1,<68.0.0" },
]
lint = [
{ name = "cffi", marker = "python_full_version < '3.10'", specifier = "<1.17.1" },
{ name = "cffi", marker = "python_full_version >= '3.10'" },
{ name = "ruff", specifier = ">=0.9.2,<1.0.0" },
]
test = [
{ name = "blockbuster", specifier = ">=1.5.18,<1.6" },
{ name = "cffi", marker = "python_full_version < '3.10'", specifier = "<1.17.1" },
{ name = "cffi", marker = "python_full_version >= '3.10'" },
{ name = "duckdb-engine", specifier = ">=0.9.2,<1.0.0" },
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-openai", editable = "../openai" },
{ name = "langchain-tests", editable = "../../standard-tests" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "lark", specifier = ">=1.1.5,<2.0.0" },
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" },
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
{ name = "packaging", specifier = ">=24.2" },
{ name = "pandas", specifier = ">=2.0.0,<3.0.0" },
{ name = "pytest", specifier = ">=8,<9" },
{ name = "pytest-asyncio", specifier = ">=0.23.2,<1.0.0" },
{ name = "pytest-cov", specifier = ">=4.0.0,<5.0.0" },
{ name = "pytest-dotenv", specifier = ">=0.5.2,<1.0.0" },
{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
{ name = "pytest-socket", specifier = ">=0.6.0,<1.0.0" },
{ name = "pytest-watcher", specifier = ">=0.2.6,<1.0.0" },
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
{ name = "requests-mock", specifier = ">=1.11.0,<2.0.0" },
{ name = "responses", specifier = ">=0.22.0,<1.0.0" },
{ name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
{ name = "toml", specifier = ">=0.10.2" },
]
test-integration = [
{ name = "cassio", specifier = ">=0.1.0,<1.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "langchainhub", specifier = ">=0.1.16,<1.0.0" },
{ name = "pytest-vcr", specifier = ">=1.0.2,<2.0.0" },
{ name = "python-dotenv", specifier = ">=1.0.0,<2.0.0" },
{ name = "urllib3", marker = "python_full_version < '3.10'", specifier = "<2" },
{ name = "wrapt", specifier = ">=1.15.0,<2.0.0" },
]
typing = [
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "mypy", specifier = ">=1.15,<2.0" },
{ name = "mypy-protobuf", specifier = ">=3.0.0,<4.0.0" },
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.4" },
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
{ name = "types-chardet", specifier = ">=5.0.4.6,<6.0.0.0" },
{ name = "types-pytz", specifier = ">=2023.3.0.0,<2024.0.0.0" },
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
{ name = "types-redis", specifier = ">=4.3.21.6,<5.0.0.0" },
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
{ name = "types-toml", specifier = ">=0.10.8.1,<1.0.0.0" },
] ]
[[package]] [[package]]
name = "langchain-community" name = "langchain-community"
version = "0.3.22" version = "0.3.22"
source = { editable = "../../community" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "aiohttp" }, { name = "aiohttp" },
{ name = "dataclasses-json" }, { name = "dataclasses-json" },
@ -985,76 +891,9 @@ dependencies = [
{ name = "sqlalchemy" }, { name = "sqlalchemy" },
{ name = "tenacity" }, { name = "tenacity" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/04/a9/32b4fb08b82b264cba1096d7daa49de808e117046ebf9df4c382e23791db/langchain_community-0.3.22.tar.gz", hash = "sha256:36284687a9f64bc7820c0140beb3b96393f6c74c0b7ad8ba04ac35d673fe0988", size = 33230274 }
[package.metadata] wheels = [
requires-dist = [ { url = "https://files.pythonhosted.org/packages/bb/bb/ebd0f33408f95ebfdb48e2a551c50506c46efc57b836b57c792ccd14290d/langchain_community-0.3.22-py3-none-any.whl", hash = "sha256:02ecdc669408d587b9dda78462dbbe8c27168edd26bb205630d0bc753e7cce6b", size = 2529327 },
{ name = "aiohttp", specifier = ">=3.8.3,<4.0.0" },
{ name = "dataclasses-json", specifier = ">=0.5.7,<0.7" },
{ name = "httpx-sse", specifier = ">=0.4.0,<1.0.0" },
{ name = "langchain", editable = "../../langchain" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langsmith", specifier = ">=0.1.125,<0.4" },
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.2" },
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
{ name = "pydantic-settings", specifier = ">=2.4.0,<3.0.0" },
{ name = "pyyaml", specifier = ">=5.3" },
{ name = "requests", specifier = ">=2,<3" },
{ name = "sqlalchemy", specifier = ">=1.4,<3" },
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10" },
]
[package.metadata.requires-dev]
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
dev = [
{ name = "jupyter", specifier = ">=1.0.0,<2.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "setuptools", specifier = ">=67.6.1,<68.0.0" },
]
lint = [
{ name = "cffi", marker = "python_full_version < '3.10'", specifier = "<1.17.1" },
{ name = "cffi", marker = "python_full_version >= '3.10'" },
{ name = "ruff", specifier = ">=0.9,<0.10" },
]
test = [
{ name = "blockbuster", specifier = ">=1.5.18,<1.6" },
{ name = "cffi", marker = "python_full_version < '3.10'", specifier = "<1.17.1" },
{ name = "cffi", marker = "python_full_version >= '3.10'" },
{ name = "duckdb-engine", specifier = ">=0.13.6,<1.0.0" },
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
{ name = "langchain", editable = "../../langchain" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-tests", editable = "../../standard-tests" },
{ name = "lark", specifier = ">=1.1.5,<2.0.0" },
{ name = "pandas", specifier = ">=2.0.0,<3.0.0" },
{ name = "pytest", specifier = ">=7.4.4,<8.0.0" },
{ name = "pytest-asyncio", specifier = ">=0.20.3,<1.0.0" },
{ name = "pytest-cov", specifier = ">=4.1.0,<5.0.0" },
{ name = "pytest-dotenv", specifier = ">=0.5.2,<1.0.0" },
{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
{ name = "pytest-socket", specifier = ">=0.6.0,<1.0.0" },
{ name = "pytest-watcher", specifier = ">=0.2.6,<1.0.0" },
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
{ name = "requests-mock", specifier = ">=1.11.0,<2.0.0" },
{ name = "responses", specifier = ">=0.22.0,<1.0.0" },
{ name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
{ name = "toml", specifier = ">=0.10.2" },
]
test-integration = [
{ name = "pytest-vcr", specifier = ">=1.0.2,<2.0.0" },
{ name = "vcrpy", specifier = ">=6,<7" },
]
typing = [
{ name = "langchain", editable = "../../langchain" },
{ name = "langchain-core", editable = "../../core" },
{ name = "langchain-text-splitters", editable = "../../text-splitters" },
{ name = "mypy", specifier = ">=1.15,<2.0" },
{ name = "mypy-protobuf", specifier = ">=3.0.0,<4.0.0" },
{ name = "types-chardet", specifier = ">=5.0.4.6,<6.0.0.0" },
{ name = "types-pytz", specifier = ">=2023.3.0.0,<2024.0.0.0" },
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
{ name = "types-redis", specifier = ">=4.3.21.6,<5.0.0.0" },
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
{ name = "types-toml", specifier = ">=0.10.8.1,<1.0.0.0" },
] ]
[[package]] [[package]]
@ -1172,7 +1011,7 @@ dev = [
] ]
lint = [{ name = "ruff", specifier = ">=0.5,<1.0" }] lint = [{ name = "ruff", specifier = ">=0.5,<1.0" }]
test = [ test = [
{ name = "langchain-community", editable = "../../community" }, { name = "langchain-community" },
{ name = "langchain-core", editable = "../../core" }, { name = "langchain-core", editable = "../../core" },
{ name = "langchain-tests", editable = "../../standard-tests" }, { name = "langchain-tests", editable = "../../standard-tests" },
{ name = "pytest", specifier = ">=7.3.0,<8.0.0" }, { name = "pytest", specifier = ">=7.3.0,<8.0.0" },
@ -1228,45 +1067,13 @@ typing = [
[[package]] [[package]]
name = "langchain-text-splitters" name = "langchain-text-splitters"
version = "0.3.8" version = "0.3.8"
source = { editable = "../../text-splitters" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "langchain-core" }, { name = "langchain-core" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/e7/ac/b4a25c5716bb0103b1515f1f52cc69ffb1035a5a225ee5afe3aed28bf57b/langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e", size = 42128 }
[package.metadata] wheels = [
requires-dist = [{ name = "langchain-core", editable = "../../core" }] { url = "https://files.pythonhosted.org/packages/8b/a3/3696ff2444658053c01b6b7443e761f28bb71217d82bb89137a978c5f66f/langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02", size = 32440 },
[package.metadata.requires-dev]
dev = [
{ name = "jupyter", specifier = ">=1.0.0,<2.0.0" },
{ name = "langchain-core", editable = "../../core" },
]
lint = [
{ name = "langchain-core", editable = "../../core" },
{ name = "ruff", specifier = ">=0.9.2,<1.0.0" },
]
test = [
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "pytest", specifier = ">=8,<9" },
{ name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" },
{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
{ name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" },
{ name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
]
test-integration = [
{ name = "nltk", specifier = ">=3.9.1,<4.0.0" },
{ name = "sentence-transformers", marker = "python_full_version < '3.13'", specifier = ">=2.6.0" },
{ name = "spacy", marker = "python_full_version < '3.10'", specifier = ">=3.0.0,<3.8.4" },
{ name = "spacy", marker = "python_full_version < '3.13'", specifier = ">=3.0.0,<4.0.0" },
{ name = "transformers", specifier = ">=4.47.0,<5.0.0" },
]
typing = [
{ name = "lxml-stubs", specifier = ">=0.5.1,<1.0.0" },
{ name = "mypy", specifier = ">=1.15,<2.0" },
{ name = "tiktoken", specifier = ">=0.8.0,<1.0.0" },
{ name = "types-requests", specifier = ">=2.31.0.20240218,<3.0.0.0" },
] ]
[[package]] [[package]]