partners: (langchain-huggingface) Chat Models - Integrate Hugging Face Inference Providers and remove deprecated code (#30733)

Hi there, I'm Célina from 🤗.
This PR introduces support for Hugging Face's serverless Inference
Providers (documentation
[here](https://huggingface.co/docs/inference-providers/index)), allowing
users to specify different providers for chat completion and text
generation tasks.
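
For illustration, a minimal sketch of what provider selection looks like on the LangChain side (the model id, provider name, and parameters below are illustrative, and a valid `HF_TOKEN` is assumed in the environment):

```python
from langchain_huggingface import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    provider="together",  # new field introduced by this PR; defaults to "hf-inference"
    task="text-generation",
    max_new_tokens=128,
)
print(llm.invoke("What is an inference provider?"))
```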

This PR also removes the usage of the `InferenceClient.post()` method in
`HuggingFaceEndpoint` in favor of the task-specific `text_generation`
method. `InferenceClient.post()` is deprecated and will be removed in
`huggingface_hub` v0.31.0.
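
Roughly, the endpoint now goes through the task-specific client method; a hedged sketch of the underlying `huggingface_hub` call (the prompt and arguments are illustrative, not taken from this PR):

```python
from huggingface_hub import InferenceClient

# Instead of the deprecated client.post(...), the endpoint now relies on the
# task-specific helper; this direct call is only for illustration.
client = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct")  # assumes HF_TOKEN is set
print(client.text_generation("Write a haiku about GPUs.", max_new_tokens=64))
```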

---
## Changes made
- Bumped the minimum required version of the `huggingface-hub` package
to ensure compatibility with the latest API usage.
- Added a `provider` field to `HuggingFaceEndpoint`, enabling users to
select the inference provider (e.g., 'cerebras', 'together',
'fireworks-ai'). Defaults to `hf-inference` (the HF Inference API).
- Replaced the deprecated `InferenceClient.post()` call in
`HuggingFaceEndpoint` with the task-specific `text_generation` method;
`post()` will be removed in `huggingface-hub` v0.31.0.
- Updated the `ChatHuggingFace` component (see the usage sketch below):
    - added async and streaming support.
    - added support for tool calling.
    - exposed underlying chat completion parameters for more granular
      control.
- Added integration tests for `ChatHuggingFace` and updated the
corresponding unit tests.

All changes are backward compatible.
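
A usage sketch of the updated `ChatHuggingFace` (the tool, model id, and parameter values are illustrative, not taken from this PR's tests; a valid `HF_TOKEN` is assumed in the environment):

```python
from langchain_core.tools import tool
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

@tool
def get_weather(city: str) -> str:
    """Return the current weather for a city."""
    return f"Sunny in {city}"

llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.1-70B-Instruct",  # illustrative model id
    task="text-generation",
)
chat = ChatHuggingFace(llm=llm, temperature=0.2, max_tokens=256)

# Tool calling
msg = chat.bind_tools([get_weather]).invoke("What's the weather in Paris?")
print(msg.tool_calls)

# Streaming
for chunk in chat.stream("Tell me a short joke."):
    print(chunk.content, end="", flush=True)
```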

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
Authored by célina, committed via GitHub on 2025-04-29 15:53:14 +02:00
commit 868f07f8f4, parent 3072e4610a
8 changed files with 699 additions and 504 deletions


@@ -2,14 +2,7 @@ from langchain_huggingface.chat_models.huggingface import ( # type: ignore[impo
TGI_MESSAGE,
TGI_RESPONSE,
ChatHuggingFace,
_convert_message_to_chat_message,
_convert_TGI_message_to_LC_message,
_convert_dict_to_message,
)
__all__ = [
"ChatHuggingFace",
"_convert_message_to_chat_message",
"_convert_TGI_message_to_LC_message",
"TGI_MESSAGE",
"TGI_RESPONSE",
]
__all__ = ["ChatHuggingFace", "_convert_dict_to_message", "TGI_MESSAGE", "TGI_RESPONSE"]


@@ -1,42 +1,65 @@
"""Hugging Face Chat Wrapper."""
import json
from collections.abc import Sequence
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from dataclasses import dataclass
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
cast,
)
from operator import itemgetter
from typing import Any, Callable, Literal, Optional, Union, cast
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.runnables import Runnable
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
LLMResult,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import model_validator
from langchain_core.utils.function_calling import (
convert_to_json_schema,
convert_to_openai_tool,
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
from ..llms.huggingface_endpoint import HuggingFaceEndpoint
from ..llms.huggingface_pipeline import HuggingFacePipeline
@dataclass
@@ -56,66 +79,143 @@ class TGI_MESSAGE:
tool_calls: list[dict]
def _convert_message_to_chat_message(
message: BaseMessage,
def _lc_tool_call_to_hf_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_hf_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: dict[str, Any]
if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content)
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs:
tool_calls = [
{
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
for tc in message.additional_kwargs["tool_calls"]
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_hf_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_hf_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if "tool_calls" in message_dict and message_dict["content"] == "":
message_dict["content"] = None
else:
tool_calls = None
return {
"role": "assistant",
"content": message.content,
"tool_calls": tool_calls,
}
pass
elif isinstance(message, SystemMessage):
return dict(role="system", content=message.content)
elif isinstance(message, ToolMessage):
return {
"role": "tool",
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def _convert_TGI_message_to_LC_message(
_message: TGI_MESSAGE,
) -> BaseMessage:
role = _message.role
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
content = cast(str, _message.content)
if content is None:
content = ""
additional_kwargs: dict = {}
if tool_calls := _message.tool_calls:
if "arguments" in tool_calls[0]["function"]:
functions = tool_calls[0]["function"].pop("arguments")
tool_calls[0]["function"]["arguments"] = json.dumps(
functions, ensure_ascii=False
)
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
content = _dict.get("content", "") or ""
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name", "")
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id", ""),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role or "")
def _is_huggingface_hub(llm: Any) -> bool:
try:
from langchain_community.llms.huggingface_hub import ( # type: ignore[import-not-found]
HuggingFaceHub,
from langchain_community.llms.huggingface_hub import (
HuggingFaceHub, # type: ignore[import-not-found]
)
return isinstance(llm, HuggingFaceHub)
@@ -124,10 +224,69 @@ def _is_huggingface_hub(llm: Any) -> bool:
return False
def _convert_chunk_to_message_chunk(
chunk: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
choice = chunk["choices"][0]
_dict = choice["delta"]
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: dict = {}
tool_call_chunks: list[ToolCallChunk] = []
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for rtc in raw_tool_calls:
try:
tool_call_chunks.append(
create_tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc.get("index"),
)
)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if usage := chunk.get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
usage_metadata = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
}
else:
usage_metadata = None
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore
def _is_huggingface_textgen_inference(llm: Any) -> bool:
try:
from langchain_community.llms.huggingface_text_gen_inference import ( # type: ignore[import-not-found]
HuggingFaceTextGenInference,
from langchain_community.llms.huggingface_text_gen_inference import (
HuggingFaceTextGenInference, # type: ignore[import-not-found]
)
return isinstance(llm, HuggingFaceTextGenInference)
@@ -172,11 +331,11 @@ class ChatHuggingFace(BaseChatModel):
'HuggingFacePipeline' LLM to be used.
Key init args — client params:
custom_get_token_ids: Optional[Callable[[str], List[int]]]
custom_get_token_ids: Optional[Callable[[str], list[int]]]
Optional encoder to use for counting tokens.
metadata: Optional[Dict[str, Any]]
metadata: Optional[dict[str, Any]]
Metadata to add to the run trace.
tags: Optional[List[str]]
tags: Optional[list[str]]
Tags to add to the run trace.
tokenizer: Any
verbose: bool
@@ -307,24 +466,43 @@ class ChatHuggingFace(BaseChatModel):
llm: Any
"""LLM, must be of type HuggingFaceTextGenInference, HuggingFaceEndpoint,
HuggingFaceHub, or HuggingFacePipeline."""
# TODO: Is system_message used anywhere?
system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
tokenizer: Any = None
"""Tokenizer for the model. Only used for HuggingFacePipeline."""
model_id: Optional[str] = None
"""Model ID for the model. Only used for HuggingFaceEndpoint."""
temperature: Optional[float] = None
"""What sampling temperature to use."""
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens."""
frequency_penalty: Optional[float] = None
"""Penalizes repeated tokens according to frequency."""
seed: Optional[int] = None
"""Seed for generation"""
logprobs: Optional[bool] = None
"""Whether to return logprobs."""
top_logprobs: Optional[int] = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true
if this parameter is used."""
logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
n: Optional[int] = None
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from transformers import AutoTokenizer # type: ignore[import]
self._resolve_model_id()
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
else self.tokenizer
)
@model_validator(mode="after")
def validate_llm(self) -> Self:
if (
@@ -340,17 +518,30 @@ class ChatHuggingFace(BaseChatModel):
)
return self
def _create_chat_result(self, response: TGI_RESPONSE) -> ChatResult:
def _create_chat_result(self, response: dict) -> ChatResult:
generations = []
finish_reason = response.choices[0].finish_reason
gen = ChatGeneration(
message=_convert_TGI_message_to_LC_message(response.choices[0].message),
generation_info={"finish_reason": finish_reason},
)
generations.append(gen)
token_usage = response.usage
model_object = self.llm.inference_server_url
llm_output = {"token_usage": token_usage, "model": model_object}
token_usage = response.get("usage", {})
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
if token_usage and isinstance(message, AIMessage):
message.usage_metadata = {
"input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(
message=message,
generation_info=generation_info,
)
generations.append(gen)
llm_output = {
"token_usage": token_usage,
"model_name": self.model_id,
"system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
def _generate(
@@ -358,18 +549,38 @@ class ChatHuggingFace(BaseChatModel):
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if _is_huggingface_textgen_inference(self.llm):
message_dicts = self._create_message_dicts(messages, stop)
message_dicts, params = self._create_message_dicts(messages, stop)
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
message_dicts = self._create_message_dicts(messages, stop)
answer = self.llm.client.chat_completion(messages=message_dicts, **kwargs)
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
"stop": stop,
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
answer = self.llm.client.chat_completion(messages=message_dicts, **params)
return self._create_chat_result(answer)
else:
llm_input = self._to_chat_prompt(messages)
if should_stream:
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -380,12 +591,36 @@ class ChatHuggingFace(BaseChatModel):
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
if _is_huggingface_textgen_inference(self.llm):
message_dicts = self._create_message_dicts(messages, stop)
message_dicts, params = self._create_message_dicts(messages, stop)
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
answer = await self.llm.async_client.chat_completion(
messages=message_dicts, **params
)
return self._create_chat_result(answer)
elif _is_huggingface_pipeline(self.llm):
raise NotImplementedError(
"async generation is not supported with HuggingFacePipeline"
)
else:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
@@ -393,6 +628,93 @@ class ChatHuggingFace(BaseChatModel):
)
return self._to_chat_result(llm_result)
def _stream(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if _is_huggingface_endpoint(self.llm):
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.llm.client.chat_completion(
messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
message_chunk = _convert_chunk_to_message_chunk(
chunk, default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
generation_info["model_name"] = self.model_id
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
yield generation_chunk
else:
llm_input = self._to_chat_prompt(messages)
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
for chunk in stream_iter: # chunk is a GenerationChunk
chat_chunk = ChatGenerationChunk(
message=AIMessageChunk(content=chunk.text),
generation_info=chunk.generation_info,
)
yield chat_chunk
async def _astream(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await self.llm.async_client.chat_completion(
messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
generation_info["model_name"] = self.model_id
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=generation_chunk.text,
chunk=generation_chunk,
logprobs=logprobs,
)
yield generation_chunk
def _to_chat_prompt(
self,
messages: list[BaseMessage],
@@ -451,8 +773,18 @@ class ChatHuggingFace(BaseChatModel):
elif _is_huggingface_textgen_inference(self.llm):
endpoint_url: Optional[str] = self.llm.inference_server_url
elif _is_huggingface_pipeline(self.llm):
from transformers import AutoTokenizer # type: ignore[import]
self.tokenizer = (
AutoTokenizer.from_pretrained(self.model_id)
if self.tokenizer is None
else self.tokenizer
)
self.model_id = self.llm.model_id
return
elif _is_huggingface_endpoint(self.llm):
self.model_id = self.llm.repo_id or self.llm.model
return
else:
endpoint_url = self.llm.endpoint_url
available_endpoints = list_inference_endpoints("*")
@@ -525,11 +857,153 @@ class ChatHuggingFace(BaseChatModel):
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Optional[Union[dict, type[BaseModel]]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a typedDict class (support added in 0.1.7),
Pydantic class is currently supported.
method: The method for steering model generation, one of:
- "function_calling": uses tool-calling features.
- "json_schema": uses dedicated structured output features.
- "json_mode": uses JSON mode.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
an instance of ``schema`` (i.e., a Pydantic object).
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
If ``include_raw`` is True, then Runnable outputs a dict with keys:
- ``"raw"``: BaseMessage
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException]
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
formatted_tool = convert_to_openai_tool(schema)
tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
ls_structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": formatted_tool,
},
)
if is_pydantic_schema:
raise NotImplementedError(
"Pydantic schema is not supported for function calling"
)
else:
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
)
elif method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is 'json_schema'. "
"Received None."
)
formatted_schema = convert_to_json_schema(schema)
llm = self.bind(
response_format={"type": "json_object", "schema": formatted_schema},
ls_structured_output_format={
"kwargs": {"method": "json_schema"},
"schema": schema,
},
)
output_parser: Union[ # type: ignore[no-redef]
JsonOutputKeyToolsParser, JsonOutputParser
] = JsonOutputParser() # type: ignore[arg-type]
elif method == "json_mode":
llm = self.bind(
response_format={"type": "json_object"},
ls_structured_output_format={
"kwargs": {"method": "json_mode"},
"schema": schema,
},
)
output_parser: Union[ # type: ignore[no-redef]
JsonOutputKeyToolsParser, JsonOutputParser
] = JsonOutputParser() # type: ignore[arg-type]
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _create_message_dicts(
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> list[dict[Any, Any]]:
message_dicts = [_convert_message_to_chat_message(m) for m in messages]
return message_dicts
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
params = self._default_params
if stop is not None:
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Hugging Face
Inference Providers API."""
params = {
"model": self.model_id,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
"stop": self.stop,
**(self.model_kwargs if self.model_kwargs else {}),
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
return params
@property
def _llm_type(self) -> str: