Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-08 14:31:55 +00:00)
huggingface[patch]: ruff fixes and rules (#31912)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
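The changes below are mechanical applications of a stricter ruff rule set. Three patterns dominate: exception messages are assigned to a variable before raising (ruff's EM-style rules), `elif`/`else` branches that follow a `return` are flattened into plain `if` blocks (RET505-style), and `try`/`except KeyError: pass` becomes `contextlib.suppress` (SIM105-style). A minimal sketch of the resulting style follows; the function and data here are illustrative only, and the rule codes are inferred from the changes rather than stated in the commit message.

import contextlib
from typing import Optional


def role_label(payload: dict) -> str:
    role = payload.get("role")
    if role == "user":
        return "human"
    if role == "assistant":  # plain `if`: the previous branch already returned
        return "ai"
    # the message goes into a variable first, then is raised (EM-style fix)
    msg = f"Got unknown role {role}"
    raise ValueError(msg)


def first_tool_name(payload: dict) -> Optional[str]:
    name = None
    # replaces `try: ... except KeyError: pass` (SIM105-style fix)
    with contextlib.suppress(KeyError):
        name = payload["tool_calls"][0]["function"]["name"]
    return name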
@@ -1,5 +1,8 @@
 """Hugging Face Chat Wrapper."""
 
+from __future__ import annotations
+
+import contextlib
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from dataclasses import dataclass
@@ -58,8 +61,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self
 
-from ..llms.huggingface_endpoint import HuggingFaceEndpoint
-from ..llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
 
 
 @dataclass
@@ -111,6 +114,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 
     Returns:
         The dictionary.
+
     """
     message_dict: dict[str, Any]
     if isinstance(message, ChatMessage):
@@ -153,7 +157,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             "tool_call_id": message.tool_call_id,
         }
     else:
-        raise TypeError(f"Got unknown type {message}")
+        msg = f"Got unknown type {message}"
+        raise TypeError(msg)
     if "name" in message.additional_kwargs:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
@@ -167,11 +172,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
 
     Returns:
         The LangChain message.
+
     """
     role = _dict.get("role")
     if role == "user":
         return HumanMessage(content=_dict.get("content", ""))
-    elif role == "assistant":
+    if role == "assistant":
         content = _dict.get("content", "") or ""
         additional_kwargs: dict = {}
         if function_call := _dict.get("function_call"):
@@ -193,13 +199,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    if role == "system":
         return SystemMessage(content=_dict.get("content", ""))
-    elif role == "function":
+    if role == "function":
         return FunctionMessage(
             content=_dict.get("content", ""), name=_dict.get("name", "")
         )
-    elif role == "tool":
+    if role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]
@@ -208,8 +214,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_call_id=_dict.get("tool_call_id", ""),
             additional_kwargs=additional_kwargs,
         )
-    else:
-        return ChatMessage(content=_dict.get("content", ""), role=role or "")
+    return ChatMessage(content=_dict.get("content", ""), role=role or "")
 
 
 def _is_huggingface_hub(llm: Any) -> bool:
@@ -241,7 +246,7 @@ def _convert_chunk_to_message_chunk(
     if raw_tool_calls := _dict.get("tool_calls"):
         additional_kwargs["tool_calls"] = raw_tool_calls
         for rtc in raw_tool_calls:
-            try:
+            with contextlib.suppress(KeyError):
                 tool_call_chunks.append(
                     create_tool_call_chunk(
                         name=rtc["function"].get("name"),
@@ -250,11 +255,9 @@ def _convert_chunk_to_message_chunk(
                         index=rtc.get("index"),
                     )
                 )
-            except KeyError:
-                pass
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         if usage := chunk.get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
@@ -271,16 +274,15 @@ def _convert_chunk_to_message_chunk(
             tool_call_chunks=tool_call_chunks,
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
+    if role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
+    if role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore
+    return default_class(content=content)  # type: ignore[call-arg]
 
 
 def _is_huggingface_textgen_inference(llm: Any) -> bool:
@@ -304,7 +306,7 @@ def _is_huggingface_pipeline(llm: Any) -> bool:
 
 
 class ChatHuggingFace(BaseChatModel):
-    """Hugging Face LLM's as ChatModels.
+    r"""Hugging Face LLM's as ChatModels.
 
     Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
     `HuggingFaceHub`, and `HuggingFacePipeline` LLMs.
@@ -484,7 +486,7 @@ class ChatHuggingFace(BaseChatModel):
     """Whether to return logprobs."""
    top_logprobs: Optional[int] = None
     """Number of most likely tokens to return at each token position, each with
-    an associated log probability. `logprobs` must be set to true
+    an associated log probability. `logprobs` must be set to true
     if this parameter is used."""
     logit_bias: Optional[dict[int, int]] = None
     """Modify the likelihood of specified tokens appearing in the completion."""
@@ -511,11 +513,12 @@ class ChatHuggingFace(BaseChatModel):
             and not _is_huggingface_endpoint(self.llm)
             and not _is_huggingface_pipeline(self.llm)
         ):
-            raise TypeError(
+            msg = (
                 "Expected llm to be one of HuggingFaceTextGenInference, "
                 "HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
                 f"received {type(self.llm)}"
             )
+            raise TypeError(msg)
         return self
 
     def _create_chat_result(self, response: dict) -> ChatResult:
@@ -529,7 +532,7 @@ class ChatHuggingFace(BaseChatModel):
                     "output_tokens": token_usage.get("completion_tokens", 0),
                     "total_tokens": token_usage.get("total_tokens", 0),
                 }
-            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generation_info = {"finish_reason": res.get("finish_reason")}
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
             gen = ChatGeneration(
@@ -558,7 +561,7 @@ class ChatHuggingFace(BaseChatModel):
             message_dicts, params = self._create_message_dicts(messages, stop)
             answer = self.llm.client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             if should_stream:
                 stream_iter = self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
@@ -573,18 +576,17 @@ class ChatHuggingFace(BaseChatModel):
             }
             answer = self.llm.client.chat_completion(messages=message_dicts, **params)
             return self._create_chat_result(answer)
-        else:
-            llm_input = self._to_chat_prompt(messages)
+        llm_input = self._to_chat_prompt(messages)
 
-            if should_stream:
-                stream_iter = self.llm._stream(
-                    llm_input, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return generate_from_stream(stream_iter)
-            llm_result = self.llm._generate(
-                prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-            )
-            return self._to_chat_result(llm_result)
+        if should_stream:
+            stream_iter = self.llm._stream(
+                llm_input, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+        llm_result = self.llm._generate(
+            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(llm_result)
 
     async def _agenerate(
         self,
@@ -598,7 +600,7 @@ class ChatHuggingFace(BaseChatModel):
             message_dicts, params = self._create_message_dicts(messages, stop)
             answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             should_stream = stream if stream is not None else self.streaming
             if should_stream:
                 stream_iter = self._astream(
@@ -616,17 +618,14 @@ class ChatHuggingFace(BaseChatModel):
                 messages=message_dicts, **params
             )
             return self._create_chat_result(answer)
-
-        elif _is_huggingface_pipeline(self.llm):
-            raise NotImplementedError(
-                "async generation is not supported with HuggingFacePipeline"
-            )
-        else:
-            llm_input = self._to_chat_prompt(messages)
-            llm_result = await self.llm._agenerate(
-                prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-            )
-            return self._to_chat_result(llm_result)
+        if _is_huggingface_pipeline(self.llm):
+            msg = "async generation is not supported with HuggingFacePipeline"
+            raise NotImplementedError(msg)
+        llm_input = self._to_chat_prompt(messages)
+        llm_result = await self.llm._agenerate(
+            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(llm_result)
 
     def _stream(
         self,
@@ -721,10 +720,12 @@ class ChatHuggingFace(BaseChatModel):
     ) -> str:
         """Convert a list of messages into a prompt format expected by wrapped LLM."""
         if not messages:
-            raise ValueError("At least one HumanMessage must be provided!")
+            msg = "At least one HumanMessage must be provided!"
+            raise ValueError(msg)
 
         if not isinstance(messages[-1], HumanMessage):
-            raise ValueError("Last message must be a HumanMessage!")
+            msg = "Last message must be a HumanMessage!"
+            raise ValueError(msg)
 
         messages_dicts = [self._to_chatml_format(m) for m in messages]
 
@@ -734,7 +735,6 @@ class ChatHuggingFace(BaseChatModel):
 
     def _to_chatml_format(self, message: BaseMessage) -> dict:
         """Convert LangChain message to ChatML format."""
-
         if isinstance(message, SystemMessage):
             role = "system"
         elif isinstance(message, AIMessage):
@@ -742,7 +742,8 @@ class ChatHuggingFace(BaseChatModel):
         elif isinstance(message, HumanMessage):
             role = "user"
         else:
-            raise ValueError(f"Unknown message type: {type(message)}")
+            msg = f"Unknown message type: {type(message)}"
+            raise ValueError(msg)
 
         return {"role": role, "content": message.content}
 
@@ -761,8 +762,7 @@ class ChatHuggingFace(BaseChatModel):
         )
 
     def _resolve_model_id(self) -> None:
-        """Resolve the model_id from the LLM's inference_server_url"""
-
+        """Resolve the model_id from the LLM's inference_server_url."""
         from huggingface_hub import list_inference_endpoints  # type: ignore[import]
 
         if _is_huggingface_hub(self.llm) or (
@@ -770,9 +770,9 @@ class ChatHuggingFace(BaseChatModel):
         ):
             self.model_id = self.llm.repo_id
             return
-        elif _is_huggingface_textgen_inference(self.llm):
+        if _is_huggingface_textgen_inference(self.llm):
             endpoint_url: Optional[str] = self.llm.inference_server_url
-        elif _is_huggingface_pipeline(self.llm):
+        if _is_huggingface_pipeline(self.llm):
             from transformers import AutoTokenizer  # type: ignore[import]
 
             self.model_id = self.model_id or self.llm.model_id
@@ -782,29 +782,29 @@ class ChatHuggingFace(BaseChatModel):
                 else self.tokenizer
             )
             return
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             self.model_id = self.llm.repo_id or self.llm.model
             return
-        else:
-            endpoint_url = self.llm.endpoint_url
+        endpoint_url = self.llm.endpoint_url
         available_endpoints = list_inference_endpoints("*")
         for endpoint in available_endpoints:
             if endpoint.url == endpoint_url:
                 self.model_id = endpoint.repository
 
         if not self.model_id:
-            raise ValueError(
+            msg = (
                 "Failed to resolve model_id:"
                 f"Could not find model id for inference server: {endpoint_url}"
                 "Make sure that your Hugging Face token has access to the endpoint."
             )
+            raise ValueError(msg)
 
     def bind_tools(
         self,
         tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "none", "required"], bool]
+            Union[dict, str, Literal["auto", "none", "required"], bool]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -823,15 +823,16 @@ class ChatHuggingFace(BaseChatModel):
                 {"type": "function", "function": {"name": <<tool_name>>}}.
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
-        """
-
+
+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None and tool_choice:
             if len(formatted_tools) != 1:
-                raise ValueError(
+                msg = (
                     "When specifying `tool_choice`, you must provide exactly one "
                     f"tool. Received {len(formatted_tools)} tools."
                 )
+                raise ValueError(msg)
             if isinstance(tool_choice, str):
                 if tool_choice not in ("auto", "none", "required"):
                     tool_choice = {
@@ -845,15 +846,17 @@ class ChatHuggingFace(BaseChatModel):
                     formatted_tools[0]["function"]["name"]
                     != tool_choice["function"]["name"]
                 ):
-                    raise ValueError(
+                    msg = (
                         f"Tool choice {tool_choice} was specified, but the only "
                         f"provided tool was {formatted_tools[0]['function']['name']}."
                     )
+                    raise ValueError(msg)
             else:
-                raise ValueError(
+                msg = (
                     f"Unrecognized tool_choice type. Expected str, bool or dict. "
                     f"Received: {tool_choice}"
                 )
+                raise ValueError(msg)
             kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)
 
@@ -892,6 +895,11 @@ class ChatHuggingFace(BaseChatModel):
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
 
+            kwargs:
+                Additional parameters to pass to the underlying LLM's
+                :meth:`langchain_core.language_models.chat.BaseChatModel.bind`
+                method, such as `response_format` or `ls_structured_output_format`.
+
         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
 
@@ -908,14 +916,16 @@ class ChatHuggingFace(BaseChatModel):
         """  # noqa: E501
         _ = kwargs.pop("strict", None)
         if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
         is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
         if method == "function_calling":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_tool = convert_to_openai_tool(schema)
             tool_name = formatted_tool["function"]["name"]
             llm = self.bind_tools(
@@ -927,19 +937,18 @@ class ChatHuggingFace(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                raise NotImplementedError(
-                    "Pydantic schema is not supported for function calling"
-                )
-            else:
-                output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
-                    JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
-                )
+                msg = "Pydantic schema is not supported for function calling"
+                raise NotImplementedError(msg)
+            output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
+                JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
+            )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'json_schema'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_schema = convert_to_json_schema(schema)
             llm = self.bind(
                 response_format={"type": "json_object", "schema": formatted_schema},
@@ -963,10 +972,11 @@ class ChatHuggingFace(BaseChatModel):
                 JsonOutputKeyToolsParser, JsonOutputParser
             ] = JsonOutputParser()  # type: ignore[arg-type]
         else:
-            raise ValueError(
+            msg = (
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
                 f"'json_mode'. Received: '{method}'"
             )
+            raise ValueError(msg)
 
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
@@ -977,8 +987,7 @@ class ChatHuggingFace(BaseChatModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser
 
     def _create_message_dicts(
         self, messages: list[BaseMessage], stop: Optional[list[str]]
@@ -991,8 +1000,7 @@ class ChatHuggingFace(BaseChatModel):
 
     @property
     def _default_params(self) -> dict[str, Any]:
-        """Get the default parameters for calling Hugging Face
-        Inference Providers API."""
+        """Get default parameters for calling Hugging Face Inference Providers API."""
         params = {
             "model": self.model_id,
             "stream": self.streaming,