huggingface[patch]: ruff fixes and rules (#31912)

* bump ruff deps
* add more thorough ruff rules
* fix said rules
Mason Daugherty 2025-07-08 10:07:57 -04:00 committed by GitHub
parent 06ab2972e3
commit 750721b4c3
15 changed files with 283 additions and 196 deletions

View File

@@ -87,7 +87,16 @@ select = [
     "W",    # pycodestyle warning
     "YTT",  # flake8-2020
 ]
-ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D107"]
+ignore = [
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
+]
 pyupgrade.keep-runtime-typing = true

 [tool.mypy]

View File

@@ -180,13 +180,14 @@ select = [
     "YTT",  # flake8-2020
 ]
 ignore = [
-    "D100",  # Missing docstring in public module
-    "D101",  # Missing docstring in public class
-    "D102",  # Missing docstring in public method
-    "D103",  # Missing docstring in public function
-    "D104",  # Missing docstring in public package
-    "D105",  # Missing docstring in magic method
-    "D107",  # Missing docstring in __init__
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
     "COM812",   # Messes with the formatter
     "ISC001",   # Messes with the formatter
     "PERF203",  # Rarely useful

View File

@@ -97,13 +97,14 @@ select = [
     "YTT",  # flake8-2020
 ]
 ignore = [
-    "D100",  # Missing docstring in public module
-    "D101",  # Missing docstring in public class
-    "D102",  # Missing docstring in public method
-    "D103",  # Missing docstring in public function
-    "D104",  # Missing docstring in public package
-    "D105",  # Missing docstring in magic method
-    "D107",  # Missing docstring in __init__
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
     "D214",    # Section over-indented, doesn't play well with reStructuredText
     "COM812",  # Messes with the formatter
     "ISC001",  # Messes with the formatter

View File

@@ -343,7 +343,6 @@ def function() -> Callable:
     """Dummy function.

     Args:
-    ----
         arg1: foo
         arg2: one of 'bar', 'baz'
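
The removed dashes above and the new D407 ignore entries are two sides of the same change: pydocstyle's D407 ("Missing dashed underline after section") wants a "----" line under each section header, which the Google-style sections used in these docstrings do not have. A minimal sketch of the now-accepted style, using a hypothetical function that is not part of this commit:

def scale(value: float, factor: float = 2.0) -> float:
    """Scale a value by a factor.

    Args:
        value: The number to scale.
        factor: The multiplier to apply.

    Returns:
        The scaled value.

    """
    # With D407 ignored, the missing "----" underline under "Args:" and
    # "Returns:" is no longer flagged.
    return value * factor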

View File

@@ -1,5 +1,8 @@
 """Hugging Face Chat Wrapper."""

+from __future__ import annotations
+
+import contextlib
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from dataclasses import dataclass
@@ -58,8 +61,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self

-from ..llms.huggingface_endpoint import HuggingFaceEndpoint
-from ..llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline


 @dataclass
@@ -111,6 +114,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     Returns:
         The dictionary.
+
     """
     message_dict: dict[str, Any]
     if isinstance(message, ChatMessage):
@@ -153,7 +157,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             "tool_call_id": message.tool_call_id,
         }
     else:
-        raise TypeError(f"Got unknown type {message}")
+        msg = f"Got unknown type {message}"
+        raise TypeError(msg)
     if "name" in message.additional_kwargs:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
@@ -167,11 +172,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     Returns:
         The LangChain message.
+
     """
     role = _dict.get("role")
     if role == "user":
         return HumanMessage(content=_dict.get("content", ""))
-    elif role == "assistant":
+    if role == "assistant":
         content = _dict.get("content", "") or ""
         additional_kwargs: dict = {}
         if function_call := _dict.get("function_call"):
@@ -193,13 +199,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    if role == "system":
         return SystemMessage(content=_dict.get("content", ""))
-    elif role == "function":
+    if role == "function":
         return FunctionMessage(
             content=_dict.get("content", ""), name=_dict.get("name", "")
         )
-    elif role == "tool":
+    if role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]
@@ -208,8 +214,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_call_id=_dict.get("tool_call_id", ""),
             additional_kwargs=additional_kwargs,
         )
-    else:
-        return ChatMessage(content=_dict.get("content", ""), role=role or "")
+    return ChatMessage(content=_dict.get("content", ""), role=role or "")


 def _is_huggingface_hub(llm: Any) -> bool:
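
The msg-then-raise rewrites in the hunks above (and repeated throughout the rest of this diff) come from ruff's flake8-errmsg rules (EM101/EM102), which flag string literals and f-strings passed directly to an exception constructor. A minimal sketch with a hypothetical helper, not part of this commit:

def ensure_known(message: object) -> None:
    # Before (flagged by EM102): raise TypeError(f"Got unknown type {message}")
    # After: bind the message to a name first, then raise it.
    msg = f"Got unknown type {message}"
    raise TypeError(msg)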
@@ -241,7 +246,7 @@ def _convert_chunk_to_message_chunk(
     if raw_tool_calls := _dict.get("tool_calls"):
         additional_kwargs["tool_calls"] = raw_tool_calls
         for rtc in raw_tool_calls:
-            try:
+            with contextlib.suppress(KeyError):
                 tool_call_chunks.append(
                     create_tool_call_chunk(
                         name=rtc["function"].get("name"),
@@ -250,11 +255,9 @@
                         index=rtc.get("index"),
                     )
                 )
-            except KeyError:
-                pass
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         if usage := chunk.get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
@@ -271,16 +274,15 @@
             tool_call_chunks=tool_call_chunks,
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
+    if role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
+    if role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore
+    return default_class(content=content)  # type: ignore[call-arg]


 def _is_huggingface_textgen_inference(llm: Any) -> bool:
@@ -304,7 +306,7 @@ def _is_huggingface_pipeline(llm: Any) -> bool:
 class ChatHuggingFace(BaseChatModel):
-    """Hugging Face LLM's as ChatModels.
+    r"""Hugging Face LLM's as ChatModels.

     Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
     `HuggingFaceHub`, and `HuggingFacePipeline` LLMs.
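
The contextlib.suppress rewrite above replaces a try/except/pass block, the usual fix for ruff's flake8-simplify SIM105 finding (the rule code itself is not named in the diff). A minimal sketch with hypothetical data, not taken from this module:

import contextlib

rtc = {"function": {"arguments": "{}"}}  # hypothetical tool-call entry with no "name" key
names: list[str] = []

# Same effect as: try: names.append(rtc["function"]["name"]) / except KeyError: pass
with contextlib.suppress(KeyError):
    names.append(rtc["function"]["name"])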
@@ -484,7 +486,7 @@ class ChatHuggingFace(BaseChatModel):
     """Whether to return logprobs."""
     top_logprobs: Optional[int] = None
     """Number of most likely tokens to return at each token position, each with
    an associated log probability. `logprobs` must be set to true
    if this parameter is used."""
     logit_bias: Optional[dict[int, int]] = None
     """Modify the likelihood of specified tokens appearing in the completion."""
@@ -511,11 +513,12 @@ class ChatHuggingFace(BaseChatModel):
             and not _is_huggingface_endpoint(self.llm)
             and not _is_huggingface_pipeline(self.llm)
         ):
-            raise TypeError(
+            msg = (
                 "Expected llm to be one of HuggingFaceTextGenInference, "
                 "HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
                 f"received {type(self.llm)}"
             )
+            raise TypeError(msg)
         return self

     def _create_chat_result(self, response: dict) -> ChatResult:
@@ -529,7 +532,7 @@ class ChatHuggingFace(BaseChatModel):
                 "output_tokens": token_usage.get("completion_tokens", 0),
                 "total_tokens": token_usage.get("total_tokens", 0),
             }
-        generation_info = dict(finish_reason=res.get("finish_reason"))
+        generation_info = {"finish_reason": res.get("finish_reason")}
         if "logprobs" in res:
             generation_info["logprobs"] = res["logprobs"]
         gen = ChatGeneration(
@@ -558,7 +561,7 @@ class ChatHuggingFace(BaseChatModel):
             message_dicts, params = self._create_message_dicts(messages, stop)
             answer = self.llm.client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             if should_stream:
                 stream_iter = self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
@@ -573,18 +576,17 @@ class ChatHuggingFace(BaseChatModel):
             }
             answer = self.llm.client.chat_completion(messages=message_dicts, **params)
             return self._create_chat_result(answer)
-        else:
-            llm_input = self._to_chat_prompt(messages)
+        llm_input = self._to_chat_prompt(messages)

         if should_stream:
             stream_iter = self.llm._stream(
                 llm_input, stop=stop, run_manager=run_manager, **kwargs
             )
             return generate_from_stream(stream_iter)
         llm_result = self.llm._generate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
         )
         return self._to_chat_result(llm_result)

     async def _agenerate(
         self,
@@ -598,7 +600,7 @@ class ChatHuggingFace(BaseChatModel):
             message_dicts, params = self._create_message_dicts(messages, stop)
             answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             should_stream = stream if stream is not None else self.streaming
             if should_stream:
                 stream_iter = self._astream(
@@ -616,17 +618,14 @@ class ChatHuggingFace(BaseChatModel):
                 messages=message_dicts, **params
             )
             return self._create_chat_result(answer)
-        elif _is_huggingface_pipeline(self.llm):
-            raise NotImplementedError(
-                "async generation is not supported with HuggingFacePipeline"
-            )
-        else:
-            llm_input = self._to_chat_prompt(messages)
-            llm_result = await self.llm._agenerate(
-                prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-            )
-            return self._to_chat_result(llm_result)
+        if _is_huggingface_pipeline(self.llm):
+            msg = "async generation is not supported with HuggingFacePipeline"
+            raise NotImplementedError(msg)
+        llm_input = self._to_chat_prompt(messages)
+        llm_result = await self.llm._agenerate(
+            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(llm_result)

     def _stream(
         self,
@@ -721,10 +720,12 @@ class ChatHuggingFace(BaseChatModel):
     ) -> str:
         """Convert a list of messages into a prompt format expected by wrapped LLM."""
         if not messages:
-            raise ValueError("At least one HumanMessage must be provided!")
+            msg = "At least one HumanMessage must be provided!"
+            raise ValueError(msg)

         if not isinstance(messages[-1], HumanMessage):
-            raise ValueError("Last message must be a HumanMessage!")
+            msg = "Last message must be a HumanMessage!"
+            raise ValueError(msg)

         messages_dicts = [self._to_chatml_format(m) for m in messages]
@@ -734,7 +735,6 @@ class ChatHuggingFace(BaseChatModel):
     def _to_chatml_format(self, message: BaseMessage) -> dict:
         """Convert LangChain message to ChatML format."""
-
         if isinstance(message, SystemMessage):
             role = "system"
         elif isinstance(message, AIMessage):
@@ -742,7 +742,8 @@ class ChatHuggingFace(BaseChatModel):
         elif isinstance(message, HumanMessage):
             role = "user"
         else:
-            raise ValueError(f"Unknown message type: {type(message)}")
+            msg = f"Unknown message type: {type(message)}"
+            raise ValueError(msg)

         return {"role": role, "content": message.content}
@@ -761,8 +762,7 @@ class ChatHuggingFace(BaseChatModel):
         )

     def _resolve_model_id(self) -> None:
-        """Resolve the model_id from the LLM's inference_server_url"""
-
+        """Resolve the model_id from the LLM's inference_server_url."""
         from huggingface_hub import list_inference_endpoints  # type: ignore[import]

         if _is_huggingface_hub(self.llm) or (
@@ -770,9 +770,9 @@ class ChatHuggingFace(BaseChatModel):
         ):
             self.model_id = self.llm.repo_id
             return
-        elif _is_huggingface_textgen_inference(self.llm):
+        if _is_huggingface_textgen_inference(self.llm):
             endpoint_url: Optional[str] = self.llm.inference_server_url
-        elif _is_huggingface_pipeline(self.llm):
+        if _is_huggingface_pipeline(self.llm):
             from transformers import AutoTokenizer  # type: ignore[import]

             self.model_id = self.model_id or self.llm.model_id
@@ -782,29 +782,29 @@ class ChatHuggingFace(BaseChatModel):
                 else self.tokenizer
             )
             return
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
            self.model_id = self.llm.repo_id or self.llm.model
            return
-        else:
-            endpoint_url = self.llm.endpoint_url
+        endpoint_url = self.llm.endpoint_url
         available_endpoints = list_inference_endpoints("*")
         for endpoint in available_endpoints:
             if endpoint.url == endpoint_url:
                 self.model_id = endpoint.repository

         if not self.model_id:
-            raise ValueError(
+            msg = (
                 "Failed to resolve model_id:"
                 f"Could not find model id for inference server: {endpoint_url}"
                 "Make sure that your Hugging Face token has access to the endpoint."
             )
+            raise ValueError(msg)

     def bind_tools(
         self,
         tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "none", "required"], bool]
+            Union[dict, str, Literal["auto", "none", "required"], bool]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -823,15 +823,16 @@ class ChatHuggingFace(BaseChatModel):
                 {"type": "function", "function": {"name": <<tool_name>>}}.
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
+
         """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None and tool_choice:
             if len(formatted_tools) != 1:
-                raise ValueError(
+                msg = (
                     "When specifying `tool_choice`, you must provide exactly one "
                     f"tool. Received {len(formatted_tools)} tools."
                 )
+                raise ValueError(msg)
             if isinstance(tool_choice, str):
                 if tool_choice not in ("auto", "none", "required"):
                     tool_choice = {
@@ -845,15 +846,17 @@
                         formatted_tools[0]["function"]["name"]
                         != tool_choice["function"]["name"]
                     ):
-                        raise ValueError(
+                        msg = (
                             f"Tool choice {tool_choice} was specified, but the only "
                             f"provided tool was {formatted_tools[0]['function']['name']}."
                         )
+                        raise ValueError(msg)
             else:
-                raise ValueError(
+                msg = (
                     f"Unrecognized tool_choice type. Expected str, bool or dict. "
                     f"Received: {tool_choice}"
                 )
+                raise ValueError(msg)
             kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)
@@ -892,6 +895,11 @@ class ChatHuggingFace(BaseChatModel):
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".

+            kwargs:
+                Additional parameters to pass to the underlying LLM's
+                :meth:`langchain_core.language_models.chat.BaseChatModel.bind`
+                method, such as `response_format` or `ls_structured_output_format`.
+
         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@@ -908,14 +916,16 @@ class ChatHuggingFace(BaseChatModel):
         """  # noqa: E501
         _ = kwargs.pop("strict", None)
         if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
         is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
         if method == "function_calling":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_tool = convert_to_openai_tool(schema)
             tool_name = formatted_tool["function"]["name"]
             llm = self.bind_tools(
@@ -927,19 +937,18 @@ class ChatHuggingFace(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                raise NotImplementedError(
-                    "Pydantic schema is not supported for function calling"
-                )
-            else:
-                output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
-                    JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
-                )
+                msg = "Pydantic schema is not supported for function calling"
+                raise NotImplementedError(msg)
+            output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
+                JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
+            )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'json_schema'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_schema = convert_to_json_schema(schema)
             llm = self.bind(
                 response_format={"type": "json_object", "schema": formatted_schema},
@@ -963,10 +972,11 @@ class ChatHuggingFace(BaseChatModel):
                 JsonOutputKeyToolsParser, JsonOutputParser
             ] = JsonOutputParser()  # type: ignore[arg-type]
         else:
-            raise ValueError(
+            msg = (
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
                 f"'json_mode'. Received: '{method}'"
             )
+            raise ValueError(msg)

         if include_raw:
             parser_assign = RunnablePassthrough.assign(
@@ -977,8 +987,7 @@ class ChatHuggingFace(BaseChatModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser

     def _create_message_dicts(
         self, messages: list[BaseMessage], stop: Optional[list[str]]
@@ -991,8 +1000,7 @@ class ChatHuggingFace(BaseChatModel):
     @property
     def _default_params(self) -> dict[str, Any]:
-        """Get the default parameters for calling Hugging Face
-        Inference Providers API."""
+        """Get default parameters for calling Hugging Face Inference Providers API."""
         params = {
             "model": self.model_id,
             "stream": self.streaming,

View File

@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 from typing import Any, Optional

 from langchain_core.embeddings import Embeddings
 from pydantic import BaseModel, ConfigDict, Field

-from ..utils.import_utils import (
+from langchain_huggingface.utils.import_utils import (
     IMPORT_ERROR,
     is_ipex_available,
     is_optimum_intel_available,
@@ -33,12 +35,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
                 model_kwargs=model_kwargs,
                 encode_kwargs=encode_kwargs
             )
+
     """

     model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
     """Model name to use."""
     cache_folder: Optional[str] = None
     """Path to store models.
     Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
@@ -46,12 +49,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
     encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the documents of
    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
    `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
     query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the query of
    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
    `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
     multi_process: bool = False
@@ -65,24 +68,25 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         try:
             import sentence_transformers  # type: ignore[import]
         except ImportError as exc:
-            raise ImportError(
+            msg = (
                 "Could not import sentence_transformers python package. "
                 "Please install it with `pip install sentence-transformers`."
-            ) from exc
+            )
+            raise ImportError(msg) from exc

         if self.model_kwargs.get("backend", "torch") == "ipex":
             if not is_optimum_intel_available() or not is_ipex_available():
-                raise ImportError(
-                    f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
-                )
+                msg = f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
+                raise ImportError(msg)
             if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
-                raise ImportError(
+                msg = (
                     f"Backend: ipex requires optimum-intel>="
                     f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
                     "`pip install --upgrade --upgrade-strategy eager "
                     "`optimum[ipex]`."
                 )
+                raise ImportError(msg)

             from optimum.intel import IPEXSentenceTransformer  # type: ignore[import]
@@ -104,21 +108,21 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     def _embed(
         self, texts: list[str], encode_kwargs: dict[str, Any]
     ) -> list[list[float]]:
-        """
-        Embed a text using the HuggingFace transformer model.
+        """Embed a text using the HuggingFace transformer model.

         Args:
             texts: The list of texts to embed.
             encode_kwargs: Keyword arguments to pass when calling the
                 `encode` method for the documents of the SentenceTransformer
                 encode method.

         Returns:
             List of embeddings, one for each text.
+
         """
         import sentence_transformers  # type: ignore[import]

-        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        texts = [x.replace("\n", " ") for x in texts]
         if self.multi_process:
             pool = self._client.start_multi_process_pool()
             embeddings = self._client.encode_multi_process(texts, pool)
@@ -127,16 +131,17 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
             embeddings = self._client.encode(
                 texts,
                 show_progress_bar=self.show_progress,
-                **encode_kwargs,  # type: ignore
+                **encode_kwargs,
             )

         if isinstance(embeddings, list):
-            raise TypeError(
+            msg = (
                 "Expected embeddings to be a Tensor or a numpy array, "
                 "got a list instead."
             )
+            raise TypeError(msg)

-        return embeddings.tolist()
+        return embeddings.tolist()  # type: ignore[return-type]

     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
@@ -146,6 +151,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
+
         """
         return self._embed(texts, self.encode_kwargs)
@@ -157,6 +163,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
+
         """
         embed_kwargs = (
             self.query_encode_kwargs

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from typing import Any, Optional
@@ -27,6 +29,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
                 task="feature-extraction",
                 huggingfacehub_api_token="my-api-key",
             )
+
     """

     client: Any = None  #: :meta private:
@@ -35,7 +38,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
     """Model name to use."""
     provider: Optional[str] = None
     """Name of the provider to use for inference with the model specified in
    ``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
    available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
     repo_id: Optional[str] = None
     """Huggingfacehub repository id, for backward compatibility."""
@@ -87,18 +90,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
             )

             if self.task not in VALID_TASKS:
-                raise ValueError(
+                msg = (
                     f"Got invalid task {self.task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
+                raise ValueError(msg)
             self.client = client
             self.async_client = async_client

         except ImportError:
-            raise ImportError(
+            msg = (
                 "Could not import huggingface_hub python package. "
                 "Please install it with `pip install huggingface_hub`."
             )
+            raise ImportError(msg)
         return self

     def embed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -109,6 +114,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
+
         """
         # replace newlines, which can negatively affect performance.
         texts = [text.replace("\n", " ") for text in texts]
@@ -125,6 +131,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
+
         """
         # replace newlines, which can negatively affect performance.
         texts = [text.replace("\n", " ") for text in texts]
@@ -142,9 +149,9 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
+
         """
-        response = self.embed_documents([text])[0]
-        return response
+        return self.embed_documents([text])[0]

     async def aembed_query(self, text: str) -> list[float]:
         """Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
@@ -154,6 +161,6 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
+
         """
-        response = (await self.aembed_documents([text]))[0]
-        return response
+        return (await self.aembed_documents([text]))[0]

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import inspect
 import logging
 import os
@@ -25,8 +27,7 @@ VALID_TASKS = (
 class HuggingFaceEndpoint(LLM):
-    """
-    Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
+    """Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.

     To use this class, you should have installed the ``huggingface_hub`` package, and
     the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
@@ -75,6 +76,7 @@ class HuggingFaceEndpoint(LLM):
                 huggingfacehub_api_token="my-api-key"
             )
             print(llm.invoke("What is Deep Learning?"))
+
     """  # noqa: E501

     endpoint_url: Optional[str] = None
@@ -147,7 +149,8 @@ class HuggingFaceEndpoint(LLM):
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
             if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
+                msg = f"Found {field_name} supplied twice."
+                raise ValueError(msg)
             if field_name not in all_required_field_names:
                 logger.warning(
                     f"""WARNING! {field_name} is not default parameter.
@@ -158,10 +161,11 @@ class HuggingFaceEndpoint(LLM):
         invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
         if invalid_model_kwargs:
-            raise ValueError(
+            msg = (
                 f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                 f"Instead they were passed in as part of `model_kwargs` parameter."
             )
+            raise ValueError(msg)

         values["model_kwargs"] = extra
@@ -185,18 +189,20 @@ class HuggingFaceEndpoint(LLM):
         repo_id = values.get("repo_id")

         if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
-            raise ValueError(
+            msg = (
                 "Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
                 "not more than one."
             )
+            raise ValueError(msg)
         values["model"] = (
             model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
         )
         if not values["model"]:
-            raise ValueError(
+            msg = (
                 "Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
                 "model."
             )
+            raise ValueError(msg)
         return values

     @model_validator(mode="after")
@@ -274,12 +280,10 @@ class HuggingFaceEndpoint(LLM):
         """Get the identifying parameters."""
         _model_kwargs = self.model_kwargs or {}
         return {
-            **{
-                "endpoint_url": self.endpoint_url,
-                "task": self.task,
-                "provider": self.provider,
-            },
-            **{"model_kwargs": _model_kwargs},
+            "endpoint_url": self.endpoint_url,
+            "task": self.task,
+            "provider": self.provider,
+            "model_kwargs": _model_kwargs,
         }

     @property
@@ -310,19 +314,19 @@ class HuggingFaceEndpoint(LLM):
             ):
                 completion += chunk.text
             return completion
-        else:
-            response_text = self.client.text_generation(
-                prompt=prompt,
-                model=self.model,
-                **invocation_params,
-            )
+        response_text = self.client.text_generation(
+            prompt=prompt,
+            model=self.model,
+            **invocation_params,
+        )

         # Maybe the generation has stopped at one of the stop sequences:
         # then we remove this stop sequence from the end of the generated text
         for stop_seq in invocation_params["stop"]:
             if response_text[-len(stop_seq) :] == stop_seq:
                 response_text = response_text[: -len(stop_seq)]
         return response_text

     async def _acall(
         self,
@@ -339,20 +343,20 @@ class HuggingFaceEndpoint(LLM):
             ):
                 completion += chunk.text
             return completion
-        else:
-            response_text = await self.async_client.text_generation(
-                prompt=prompt,
-                **invocation_params,
-                model=self.model,
-                stream=False,
-            )
+        response_text = await self.async_client.text_generation(
+            prompt=prompt,
+            **invocation_params,
+            model=self.model,
+            stream=False,
+        )

         # Maybe the generation has stopped at one of the stop sequences:
         # then remove this stop sequence from the end of the generated text
         for stop_seq in invocation_params["stop"]:
             if response_text[-len(stop_seq) :] == stop_seq:
                 response_text = response_text[: -len(stop_seq)]
         return response_text

     def _stream(
         self,

View File

@@ -10,7 +10,7 @@ from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from pydantic import ConfigDict, model_validator

-from ..utils.import_utils import (
+from langchain_huggingface.utils.import_utils import (
     IMPORT_ERROR,
     is_ipex_available,
     is_openvino_available,
@@ -87,7 +87,7 @@ class HuggingFacePipeline(BaseLLM):
     def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Ensure model_id is set either by pipeline or user input."""
         if "model_id" not in values:
-            if "pipeline" in values and values["pipeline"]:
+            if values.get("pipeline"):
                 values["model_id"] = values["pipeline"].model.name_or_path
             else:
                 values["model_id"] = DEFAULT_MODEL_ID
@@ -116,34 +116,38 @@ class HuggingFacePipeline(BaseLLM):
             from transformers import pipeline as hf_pipeline  # type: ignore[import]
         except ImportError:
-            raise ValueError(
+            msg = (
                 "Could not import transformers python package. "
                 "Please install it with `pip install transformers`."
             )
+            raise ValueError(msg)

         _model_kwargs = model_kwargs.copy() if model_kwargs else {}
         if device_map is not None:
             if device is not None:
-                raise ValueError(
+                msg = (
                     "Both `device` and `device_map` are specified. "
                     "`device` will override `device_map`. "
                     "You will most likely encounter unexpected behavior."
                     "Please remove `device` and keep "
                     "`device_map`."
                 )
+                raise ValueError(msg)
             if "device_map" in _model_kwargs:
-                raise ValueError("`device_map` is already specified in `model_kwargs`.")
+                msg = "`device_map` is already specified in `model_kwargs`."
+                raise ValueError(msg)

             _model_kwargs["device_map"] = device_map
         tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

         if backend in {"openvino", "ipex"}:
             if task not in VALID_TASKS:
-                raise ValueError(
+                msg = (
                     f"Got invalid task {task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
+                raise ValueError(msg)

             err_msg = f'Backend: {backend} {IMPORT_ERROR.format(f"optimum[{backend}]")}'
             if not is_optimum_intel_available():
@@ -156,12 +160,13 @@ class HuggingFacePipeline(BaseLLM):
                 else _MIN_OPTIMUM_VERSION
             )
             if is_optimum_intel_version("<", min_optimum_version):
-                raise ImportError(
+                msg = (
                     f"Backend: {backend} requires optimum-intel>="
                     f"{min_optimum_version}. You can install it with pip: "
                     "`pip install --upgrade --upgrade-strategy eager "
                     f"`optimum[{backend}]`."
                 )
+                raise ImportError(msg)

             if backend == "openvino":
                 if not is_openvino_available():
@@ -240,10 +245,11 @@ class HuggingFacePipeline(BaseLLM):
             cuda_device_count = torch.cuda.device_count()
             if device < -1 or (device >= cuda_device_count):
-                raise ValueError(
+                msg = (
                     f"Got device=={device}, "
                     f"device is required to be within [-1, {cuda_device_count})"
                 )
+                raise ValueError(msg)
             if device_map is not None and device < 0:
                 device = None
             if device is not None and device < 0 and cuda_device_count > 0:
@@ -271,10 +277,11 @@ class HuggingFacePipeline(BaseLLM):
             **_pipeline_kwargs,
         )
         if pipeline.task not in VALID_TASKS:
-            raise ValueError(
+            msg = (
                 f"Got invalid task {pipeline.task}, "
                 f"currently only {VALID_TASKS} are supported"
             )
+            raise ValueError(msg)
         return cls(
             pipeline=pipeline,
             model_id=model_id,
@@ -324,21 +331,22 @@ class HuggingFacePipeline(BaseLLM):
                 # if model returns multiple generations, pick the top one
                 response = response[0]

-                if self.pipeline.task == "text-generation":
-                    text = response["generated_text"]
-                elif self.pipeline.task == "text2text-generation":
-                    text = response["generated_text"]
-                elif self.pipeline.task == "image-text-to-text":
+                if (
+                    self.pipeline.task == "text-generation"
+                    or self.pipeline.task == "text2text-generation"
+                    or self.pipeline.task == "image-text-to-text"
+                ):
                     text = response["generated_text"]
                 elif self.pipeline.task == "summarization":
                     text = response["summary_text"]
                 elif self.pipeline.task in "translation":
                     text = response["translation_text"]
                 else:
-                    raise ValueError(
+                    msg = (
                         f"Got invalid task {self.pipeline.task}, "
                         f"currently only {VALID_TASKS} are supported"
                     )
+                    raise ValueError(msg)
                 if skip_prompt:
                     text = text[len(batch_prompts[j]) :]
                 # Append the processed text to results
@@ -378,10 +386,7 @@ class HuggingFacePipeline(BaseLLM):
                 scores: torch.FloatTensor,
                 **kwargs: Any,
             ) -> bool:
-                for stop_id in stopping_ids_list:
-                    if input_ids[0][-1] == stop_id:
-                        return True
-                return False
+                return any(input_ids[0][-1] == stop_id for stop_id in stopping_ids_list)

         stopping_criteria = StoppingCriteriaList([StopOnTokens()])

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib.metadata
 import importlib.util
 import operator as op
@@ -46,8 +48,7 @@ def compare_versions(
     operation: str,
     requirement_version: str,
 ) -> bool:
-    """
-    Compare a library version to some requirement using a given operation.
+    """Compare a library version to some requirement using a given operation.

     Arguments:
         library_or_version (`str` or `packaging.version.Version`):
@@ -56,12 +57,14 @@ def compare_versions(
            A string representation of an operator, such as `">"` or `"<="`.
        requirement_version (`str`):
            The version to compare the library version against
+
     """
-    if operation not in STR_OPERATION_TO_FUNC.keys():
-        raise ValueError(
+    if operation not in STR_OPERATION_TO_FUNC:
+        msg = (
             f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}"
             f", received {operation}"
         )
+        raise ValueError(msg)
     if isinstance(library_or_version, str):
         library_or_version = version.parse(
             importlib.metadata.version(library_or_version)
@@ -88,9 +91,7 @@
 def is_optimum_version(operation: str, reference_version: str) -> bool:
-    """
-    Compare the current Optimum version to a given reference with an operation.
-    """
+    """Compare the current Optimum version to a given reference with an operation."""
     if not _optimum_version:
         return False
     return compare_versions(
@@ -99,9 +100,7 @@
 def is_optimum_intel_version(operation: str, reference_version: str) -> bool:
-    """
-    Compare the current Optimum Intel version to a given reference with an operation.
-    """
+    """Compare current Optimum Intel version to a given reference with an operation."""
     if not _optimum_intel_version:
         return False
     return compare_versions(

View File

@@ -58,8 +58,60 @@ disallow_untyped_defs = "True"
 target-version = "py39"

 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = ["UP007"]
+select = [
+    "A",      # flake8-builtins
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flake8-flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # isort-comprehensions
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # flake8-perf
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-rst-docstrings
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",     # pydocstyle: Missing docstring in public module
+    "D101",     # pydocstyle: Missing docstring in public class
+    "D102",     # pydocstyle: Missing docstring in public method
+    "D103",     # pydocstyle: Missing docstring in public function
+    "D104",     # pydocstyle: Missing docstring in public package
+    "D105",     # pydocstyle: Missing docstring in magic method
+    "D107",     # pydocstyle: Missing docstring in __init__
+    "D407",     # pydocstyle: Missing-dashed-underline-after-section
+    "COM812",   # Messes with the formatter
+    "ISC001",   # Messes with the formatter
+    "PERF203",  # Rarely useful
+    "S112",     # Rarely useful
+    "RUF012",   # Doesn't play well with Pydantic
+    "SLF001",   # Private member access
+    "UP007",    # pyupgrade: non-pep604-annotation-union
+]

 [tool.coverage.run]
 omit = ["tests/*"]
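
The "FA" (flake8-future-annotations) selection above pairs with the `from __future__ import annotations` lines added at the top of several modules in this commit: with that import, annotations are stored as strings rather than evaluated at import time, which keeps newer annotation syntax compatible with the `target-version = "py39"` floor set here. A minimal sketch with a hypothetical function, not taken from this package:

from __future__ import annotations


def first_or_none(items: list[str]) -> str | None:
    # "str | None" would fail at runtime on Python 3.9 if it were evaluated,
    # but under the future import it stays an unevaluated string annotation.
    return items[0] if items else None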

View File

@@ -4,4 +4,3 @@ import pytest  # type: ignore[import-not-found, import-not-found]
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass

View File

@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 import pytest
 from langchain_core.language_models import BaseChatModel

View File

@@ -31,8 +31,7 @@ def mock_llm() -> Mock:
     "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
 )
 def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
-    chat_hf = ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())
-    return chat_hf
+    return ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())


 def test_create_chat_result(chat_hugging_face: Any) -> None:

View File

@@ -6,8 +6,7 @@ DEFAULT_MODEL_ID = "gpt2"
 def test_initialization_default() -> None:
-    """Test default initialization"""
-
+    """Test default initialization."""
     llm = HuggingFacePipeline()

     assert llm.model_id == DEFAULT_MODEL_ID
@@ -15,8 +14,7 @@
 @patch("transformers.pipeline")
 def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
-    """Test initialization with a pipeline object"""
-
+    """Test initialization with a pipeline object."""
     mock_pipe = MagicMock()
     mock_pipe.model.name_or_path = "mock-model-id"
     mock_pipeline.return_value = mock_pipe
@@ -32,8 +30,7 @@
 def test_initialization_with_from_model_id(
     mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
 ) -> None:
-    """Test initialization with the from_model_id method"""
-
+    """Test initialization with the from_model_id method."""
     mock_tokenizer.return_value = MagicMock(pad_token_id=0)
     mock_model.return_value = MagicMock()