huggingface[patch]: ruff fixes and rules (#31912)

* bump ruff deps
* add more thorough ruff rules
* fix said rules
This commit is contained in:
Mason Daugherty
2025-07-08 10:07:57 -04:00
committed by GitHub
parent 06ab2972e3
commit 750721b4c3
15 changed files with 283 additions and 196 deletions

View File

@@ -1,5 +1,8 @@
"""Hugging Face Chat Wrapper."""
from __future__ import annotations
import contextlib
import json
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from dataclasses import dataclass
@@ -58,8 +61,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from ..llms.huggingface_endpoint import HuggingFaceEndpoint
from ..llms.huggingface_pipeline import HuggingFacePipeline
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
@dataclass
@@ -111,6 +114,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: dict[str, Any]
if isinstance(message, ChatMessage):
@@ -153,7 +157,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
msg = f"Got unknown type {message}"
raise TypeError(msg)
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
@@ -167,11 +172,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
Returns:
The LangChain message.
"""
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
if role == "assistant":
content = _dict.get("content", "") or ""
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
@@ -193,13 +199,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
if role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
if role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name", "")
)
elif role == "tool":
if role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
@@ -208,8 +214,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_call_id=_dict.get("tool_call_id", ""),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role or "")
return ChatMessage(content=_dict.get("content", ""), role=role or "")
def _is_huggingface_hub(llm: Any) -> bool:
@@ -241,7 +246,7 @@ def _convert_chunk_to_message_chunk(
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for rtc in raw_tool_calls:
try:
with contextlib.suppress(KeyError):
tool_call_chunks.append(
create_tool_call_chunk(
name=rtc["function"].get("name"),
@@ -250,11 +255,9 @@ def _convert_chunk_to_message_chunk(
index=rtc.get("index"),
)
)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if role == "assistant" or default_class == AIMessageChunk:
if usage := chunk.get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
@@ -271,16 +274,15 @@ def _convert_chunk_to_message_chunk(
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
if role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
if role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore
return default_class(content=content) # type: ignore[call-arg]
def _is_huggingface_textgen_inference(llm: Any) -> bool:
@@ -304,7 +306,7 @@ def _is_huggingface_pipeline(llm: Any) -> bool:
class ChatHuggingFace(BaseChatModel):
"""Hugging Face LLM's as ChatModels.
r"""Hugging Face LLM's as ChatModels.
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
`HuggingFaceHub`, and `HuggingFacePipeline` LLMs.
@@ -484,7 +486,7 @@ class ChatHuggingFace(BaseChatModel):
"""Whether to return logprobs."""
top_logprobs: Optional[int] = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true
an associated log probability. `logprobs` must be set to true
if this parameter is used."""
logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion."""
@@ -511,11 +513,12 @@ class ChatHuggingFace(BaseChatModel):
and not _is_huggingface_endpoint(self.llm)
and not _is_huggingface_pipeline(self.llm)
):
raise TypeError(
msg = (
"Expected llm to be one of HuggingFaceTextGenInference, "
"HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
f"received {type(self.llm)}"
)
raise TypeError(msg)
return self
def _create_chat_result(self, response: dict) -> ChatResult:
@@ -529,7 +532,7 @@ class ChatHuggingFace(BaseChatModel):
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
generation_info = dict(finish_reason=res.get("finish_reason"))
generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(
@@ -558,7 +561,7 @@ class ChatHuggingFace(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
if _is_huggingface_endpoint(self.llm):
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@@ -573,18 +576,17 @@ class ChatHuggingFace(BaseChatModel):
}
answer = self.llm.client.chat_completion(messages=message_dicts, **params)
return self._create_chat_result(answer)
else:
llm_input = self._to_chat_prompt(messages)
llm_input = self._to_chat_prompt(messages)
if should_stream:
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
if should_stream:
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
return generate_from_stream(stream_iter)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
@@ -598,7 +600,7 @@ class ChatHuggingFace(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
if _is_huggingface_endpoint(self.llm):
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
@@ -616,17 +618,14 @@ class ChatHuggingFace(BaseChatModel):
messages=message_dicts, **params
)
return self._create_chat_result(answer)
elif _is_huggingface_pipeline(self.llm):
raise NotImplementedError(
"async generation is not supported with HuggingFacePipeline"
)
else:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
if _is_huggingface_pipeline(self.llm):
msg = "async generation is not supported with HuggingFacePipeline"
raise NotImplementedError(msg)
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _stream(
self,
@@ -721,10 +720,12 @@ class ChatHuggingFace(BaseChatModel):
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")
msg = "At least one HumanMessage must be provided!"
raise ValueError(msg)
if not isinstance(messages[-1], HumanMessage):
raise ValueError("Last message must be a HumanMessage!")
msg = "Last message must be a HumanMessage!"
raise ValueError(msg)
messages_dicts = [self._to_chatml_format(m) for m in messages]
@@ -734,7 +735,6 @@ class ChatHuggingFace(BaseChatModel):
def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
@@ -742,7 +742,8 @@ class ChatHuggingFace(BaseChatModel):
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(f"Unknown message type: {type(message)}")
msg = f"Unknown message type: {type(message)}"
raise ValueError(msg)
return {"role": role, "content": message.content}
@@ -761,8 +762,7 @@ class ChatHuggingFace(BaseChatModel):
)
def _resolve_model_id(self) -> None:
"""Resolve the model_id from the LLM's inference_server_url"""
"""Resolve the model_id from the LLM's inference_server_url."""
from huggingface_hub import list_inference_endpoints # type: ignore[import]
if _is_huggingface_hub(self.llm) or (
@@ -770,9 +770,9 @@ class ChatHuggingFace(BaseChatModel):
):
self.model_id = self.llm.repo_id
return
elif _is_huggingface_textgen_inference(self.llm):
if _is_huggingface_textgen_inference(self.llm):
endpoint_url: Optional[str] = self.llm.inference_server_url
elif _is_huggingface_pipeline(self.llm):
if _is_huggingface_pipeline(self.llm):
from transformers import AutoTokenizer # type: ignore[import]
self.model_id = self.model_id or self.llm.model_id
@@ -782,29 +782,29 @@ class ChatHuggingFace(BaseChatModel):
else self.tokenizer
)
return
elif _is_huggingface_endpoint(self.llm):
if _is_huggingface_endpoint(self.llm):
self.model_id = self.llm.repo_id or self.llm.model
return
else:
endpoint_url = self.llm.endpoint_url
endpoint_url = self.llm.endpoint_url
available_endpoints = list_inference_endpoints("*")
for endpoint in available_endpoints:
if endpoint.url == endpoint_url:
self.model_id = endpoint.repository
if not self.model_id:
raise ValueError(
msg = (
"Failed to resolve model_id:"
f"Could not find model id for inference server: {endpoint_url}"
"Make sure that your Hugging Face token has access to the endpoint."
)
raise ValueError(msg)
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required"], bool]
Union[dict, str, Literal["auto", "none", "required"], bool] # noqa: PYI051
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -823,15 +823,16 @@ class ChatHuggingFace(BaseChatModel):
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
msg = (
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
raise ValueError(msg)
if isinstance(tool_choice, str):
if tool_choice not in ("auto", "none", "required"):
tool_choice = {
@@ -845,15 +846,17 @@ class ChatHuggingFace(BaseChatModel):
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
msg = (
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
raise ValueError(msg)
else:
raise ValueError(
msg = (
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
raise ValueError(msg)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
@@ -892,6 +895,11 @@ class ChatHuggingFace(BaseChatModel):
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs:
Additional parameters to pass to the underlying LLM's
:meth:`langchain_core.language_models.chat.BaseChatModel.bind`
method, such as `response_format` or `ls_structured_output_format`.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@@ -908,14 +916,16 @@ class ChatHuggingFace(BaseChatModel):
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is 'function_calling'. "
"Received None."
)
raise ValueError(msg)
formatted_tool = convert_to_openai_tool(schema)
tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
@@ -927,19 +937,18 @@ class ChatHuggingFace(BaseChatModel):
},
)
if is_pydantic_schema:
raise NotImplementedError(
"Pydantic schema is not supported for function calling"
)
else:
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
)
msg = "Pydantic schema is not supported for function calling"
raise NotImplementedError(msg)
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
)
elif method == "json_schema":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is 'json_schema'. "
"Received None."
)
raise ValueError(msg)
formatted_schema = convert_to_json_schema(schema)
llm = self.bind(
response_format={"type": "json_object", "schema": formatted_schema},
@@ -963,10 +972,11 @@ class ChatHuggingFace(BaseChatModel):
JsonOutputKeyToolsParser, JsonOutputParser
] = JsonOutputParser() # type: ignore[arg-type]
else:
raise ValueError(
msg = (
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
raise ValueError(msg)
if include_raw:
parser_assign = RunnablePassthrough.assign(
@@ -977,8 +987,7 @@ class ChatHuggingFace(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
return llm | output_parser
def _create_message_dicts(
self, messages: list[BaseMessage], stop: Optional[list[str]]
@@ -991,8 +1000,7 @@ class ChatHuggingFace(BaseChatModel):
@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Hugging Face
Inference Providers API."""
"""Get default parameters for calling Hugging Face Inference Providers API."""
params = {
"model": self.model_id,
"stream": self.streaming,