huggingface[patch]: ruff fixes and rules (#31912)

* bump ruff deps
* add more thorough ruff rules
* fix said rules
Mason Daugherty 2025-07-08 10:07:57 -04:00 committed by GitHub
parent 06ab2972e3
commit 750721b4c3
15 changed files with 283 additions and 196 deletions

View File

@ -87,7 +87,16 @@ select = [
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D107"]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D407", # pydocstyle: Missing-dashed-underline-after-section
]
pyupgrade.keep-runtime-typing = true
[tool.mypy]
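
A recurring change throughout the diffs below comes from the newly selected EM (flake8-errmsg) rules, which flag string literals and f-strings passed directly to an exception constructor (EM101/EM102). A minimal sketch of the before/after shape, using a hypothetical check_positive helper for illustration:

def check_positive(value: int) -> int:
    if value <= 0:
        # EM102 flags: raise ValueError(f"expected a positive value, got {value}")
        # so the message is bound to a variable first:
        msg = f"expected a positive value, got {value}"
        raise ValueError(msg)
    return value

The same pattern (msg = ...; raise SomeError(msg)) shows up in nearly every file touched by this commit.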

View File

@ -180,13 +180,14 @@ select = [
"YTT", # flake8-2020
]
ignore = [
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D107", # Missing docstring in __init__
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful

View File

@ -97,13 +97,14 @@ select = [
"YTT", # flake8-2020
]
ignore = [
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D107", # Missing docstring in __init__
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D407", # pydocstyle: Missing-dashed-underline-after-section
"D214", # Section over-indented, doesn't play well with reStructuredText
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter

View File

@ -343,7 +343,6 @@ def function() -> Callable:
"""Dummy function.
Args:
----
arg1: foo
arg2: one of 'bar', 'baz'
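
With D407 now ignored, Google-style docstring sections no longer need a dashed underline, which is why the ---- line under Args: is dropped from the dummy function above. A hypothetical example of the accepted shape:

def scale(x: float, factor: float) -> float:
    """Scale a value.

    Args:
        x: value to scale.
        factor: multiplier applied to x.
    """
    return x * factor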

View File

@ -1,5 +1,8 @@
"""Hugging Face Chat Wrapper."""
from __future__ import annotations
import contextlib
import json
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
from dataclasses import dataclass
@ -58,8 +61,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from ..llms.huggingface_endpoint import HuggingFaceEndpoint
from ..llms.huggingface_pipeline import HuggingFacePipeline
from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline
@dataclass
@ -111,6 +114,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: dict[str, Any]
if isinstance(message, ChatMessage):
@ -153,7 +157,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
msg = f"Got unknown type {message}"
raise TypeError(msg)
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
@ -167,11 +172,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
Returns:
The LangChain message.
"""
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
if role == "assistant":
content = _dict.get("content", "") or ""
additional_kwargs: dict = {}
if function_call := _dict.get("function_call"):
@ -193,13 +199,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
if role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
if role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=_dict.get("name", "")
)
elif role == "tool":
if role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
@ -208,8 +214,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
tool_call_id=_dict.get("tool_call_id", ""),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role or "")
return ChatMessage(content=_dict.get("content", ""), role=role or "")
def _is_huggingface_hub(llm: Any) -> bool:
@ -241,7 +246,7 @@ def _convert_chunk_to_message_chunk(
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for rtc in raw_tool_calls:
try:
with contextlib.suppress(KeyError):
tool_call_chunks.append(
create_tool_call_chunk(
name=rtc["function"].get("name"),
@ -250,11 +255,9 @@ def _convert_chunk_to_message_chunk(
index=rtc.get("index"),
)
)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if role == "assistant" or default_class == AIMessageChunk:
if usage := chunk.get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
@ -271,16 +274,15 @@ def _convert_chunk_to_message_chunk(
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
if role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
if role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore
return default_class(content=content) # type: ignore[call-arg]
def _is_huggingface_textgen_inference(llm: Any) -> bool:
@ -304,7 +306,7 @@ def _is_huggingface_pipeline(llm: Any) -> bool:
class ChatHuggingFace(BaseChatModel):
"""Hugging Face LLM's as ChatModels.
r"""Hugging Face LLM's as ChatModels.
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
`HuggingFaceHub`, and `HuggingFacePipeline` LLMs.
@ -484,7 +486,7 @@ class ChatHuggingFace(BaseChatModel):
"""Whether to return logprobs."""
top_logprobs: Optional[int] = None
"""Number of most likely tokens to return at each token position, each with
an associated log probability. `logprobs` must be set to true
an associated log probability. `logprobs` must be set to true
if this parameter is used."""
logit_bias: Optional[dict[int, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion."""
@ -511,11 +513,12 @@ class ChatHuggingFace(BaseChatModel):
and not _is_huggingface_endpoint(self.llm)
and not _is_huggingface_pipeline(self.llm)
):
raise TypeError(
msg = (
"Expected llm to be one of HuggingFaceTextGenInference, "
"HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
f"received {type(self.llm)}"
)
raise TypeError(msg)
return self
def _create_chat_result(self, response: dict) -> ChatResult:
@ -529,7 +532,7 @@ class ChatHuggingFace(BaseChatModel):
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
generation_info = dict(finish_reason=res.get("finish_reason"))
generation_info = {"finish_reason": res.get("finish_reason")}
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(
@ -558,7 +561,7 @@ class ChatHuggingFace(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
if _is_huggingface_endpoint(self.llm):
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
@ -573,18 +576,17 @@ class ChatHuggingFace(BaseChatModel):
}
answer = self.llm.client.chat_completion(messages=message_dicts, **params)
return self._create_chat_result(answer)
else:
llm_input = self._to_chat_prompt(messages)
llm_input = self._to_chat_prompt(messages)
if should_stream:
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
if should_stream:
stream_iter = self.llm._stream(
llm_input, stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
return generate_from_stream(stream_iter)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
async def _agenerate(
self,
@ -598,7 +600,7 @@ class ChatHuggingFace(BaseChatModel):
message_dicts, params = self._create_message_dicts(messages, stop)
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
return self._create_chat_result(answer)
elif _is_huggingface_endpoint(self.llm):
if _is_huggingface_endpoint(self.llm):
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
@ -616,17 +618,14 @@ class ChatHuggingFace(BaseChatModel):
messages=message_dicts, **params
)
return self._create_chat_result(answer)
elif _is_huggingface_pipeline(self.llm):
raise NotImplementedError(
"async generation is not supported with HuggingFacePipeline"
)
else:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
if _is_huggingface_pipeline(self.llm):
msg = "async generation is not supported with HuggingFacePipeline"
raise NotImplementedError(msg)
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)
def _stream(
self,
@ -721,10 +720,12 @@ class ChatHuggingFace(BaseChatModel):
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")
msg = "At least one HumanMessage must be provided!"
raise ValueError(msg)
if not isinstance(messages[-1], HumanMessage):
raise ValueError("Last message must be a HumanMessage!")
msg = "Last message must be a HumanMessage!"
raise ValueError(msg)
messages_dicts = [self._to_chatml_format(m) for m in messages]
@ -734,7 +735,6 @@ class ChatHuggingFace(BaseChatModel):
def _to_chatml_format(self, message: BaseMessage) -> dict:
"""Convert LangChain message to ChatML format."""
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, AIMessage):
@ -742,7 +742,8 @@ class ChatHuggingFace(BaseChatModel):
elif isinstance(message, HumanMessage):
role = "user"
else:
raise ValueError(f"Unknown message type: {type(message)}")
msg = f"Unknown message type: {type(message)}"
raise ValueError(msg)
return {"role": role, "content": message.content}
@ -761,8 +762,7 @@ class ChatHuggingFace(BaseChatModel):
)
def _resolve_model_id(self) -> None:
"""Resolve the model_id from the LLM's inference_server_url"""
"""Resolve the model_id from the LLM's inference_server_url."""
from huggingface_hub import list_inference_endpoints # type: ignore[import]
if _is_huggingface_hub(self.llm) or (
@ -770,9 +770,9 @@ class ChatHuggingFace(BaseChatModel):
):
self.model_id = self.llm.repo_id
return
elif _is_huggingface_textgen_inference(self.llm):
if _is_huggingface_textgen_inference(self.llm):
endpoint_url: Optional[str] = self.llm.inference_server_url
elif _is_huggingface_pipeline(self.llm):
if _is_huggingface_pipeline(self.llm):
from transformers import AutoTokenizer # type: ignore[import]
self.model_id = self.model_id or self.llm.model_id
@ -782,29 +782,29 @@ class ChatHuggingFace(BaseChatModel):
else self.tokenizer
)
return
elif _is_huggingface_endpoint(self.llm):
if _is_huggingface_endpoint(self.llm):
self.model_id = self.llm.repo_id or self.llm.model
return
else:
endpoint_url = self.llm.endpoint_url
endpoint_url = self.llm.endpoint_url
available_endpoints = list_inference_endpoints("*")
for endpoint in available_endpoints:
if endpoint.url == endpoint_url:
self.model_id = endpoint.repository
if not self.model_id:
raise ValueError(
msg = (
"Failed to resolve model_id:"
f"Could not find model id for inference server: {endpoint_url}"
"Make sure that your Hugging Face token has access to the endpoint."
)
raise ValueError(msg)
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required"], bool]
Union[dict, str, Literal["auto", "none", "required"], bool] # noqa: PYI051
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@ -823,15 +823,16 @@ class ChatHuggingFace(BaseChatModel):
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice is not None and tool_choice:
if len(formatted_tools) != 1:
raise ValueError(
msg = (
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(formatted_tools)} tools."
)
raise ValueError(msg)
if isinstance(tool_choice, str):
if tool_choice not in ("auto", "none", "required"):
tool_choice = {
@ -845,15 +846,17 @@ class ChatHuggingFace(BaseChatModel):
formatted_tools[0]["function"]["name"]
!= tool_choice["function"]["name"]
):
raise ValueError(
msg = (
f"Tool choice {tool_choice} was specified, but the only "
f"provided tool was {formatted_tools[0]['function']['name']}."
)
raise ValueError(msg)
else:
raise ValueError(
msg = (
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
raise ValueError(msg)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
@ -892,6 +895,11 @@ class ChatHuggingFace(BaseChatModel):
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs:
Additional parameters to pass to the underlying LLM's
:meth:`langchain_core.language_models.chat.BaseChatModel.bind`
method, such as `response_format` or `ls_structured_output_format`.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@ -908,14 +916,16 @@ class ChatHuggingFace(BaseChatModel):
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is 'function_calling'. "
"Received None."
)
raise ValueError(msg)
formatted_tool = convert_to_openai_tool(schema)
tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
@ -927,19 +937,18 @@ class ChatHuggingFace(BaseChatModel):
},
)
if is_pydantic_schema:
raise NotImplementedError(
"Pydantic schema is not supported for function calling"
)
else:
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
)
msg = "Pydantic schema is not supported for function calling"
raise NotImplementedError(msg)
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
)
elif method == "json_schema":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is 'json_schema'. "
"Received None."
)
raise ValueError(msg)
formatted_schema = convert_to_json_schema(schema)
llm = self.bind(
response_format={"type": "json_object", "schema": formatted_schema},
@ -963,10 +972,11 @@ class ChatHuggingFace(BaseChatModel):
JsonOutputKeyToolsParser, JsonOutputParser
] = JsonOutputParser() # type: ignore[arg-type]
else:
raise ValueError(
msg = (
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
raise ValueError(msg)
if include_raw:
parser_assign = RunnablePassthrough.assign(
@ -977,8 +987,7 @@ class ChatHuggingFace(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
return llm | output_parser
def _create_message_dicts(
self, messages: list[BaseMessage], stop: Optional[list[str]]
@ -991,8 +1000,7 @@ class ChatHuggingFace(BaseChatModel):
@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling Hugging Face
Inference Providers API."""
"""Get default parameters for calling Hugging Face Inference Providers API."""
params = {
"model": self.model_id,
"stream": self.streaming,

View File

@ -1,9 +1,11 @@
from __future__ import annotations
from typing import Any, Optional
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field
from ..utils.import_utils import (
from langchain_huggingface.utils.import_utils import (
IMPORT_ERROR,
is_ipex_available,
is_optimum_intel_available,
@ -33,12 +35,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
"""
model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
"""Model name to use."""
cache_folder: Optional[str] = None
"""Path to store models.
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
@ -46,12 +49,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method for the documents of
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
`precision`, `normalize_embeddings`, and more.
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method for the query of
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
`precision`, `normalize_embeddings`, and more.
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
multi_process: bool = False
@ -65,24 +68,25 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
try:
import sentence_transformers # type: ignore[import]
except ImportError as exc:
raise ImportError(
msg = (
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc
)
raise ImportError(msg) from exc
if self.model_kwargs.get("backend", "torch") == "ipex":
if not is_optimum_intel_available() or not is_ipex_available():
raise ImportError(
f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
)
msg = f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
raise ImportError(msg)
if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
raise ImportError(
msg = (
f"Backend: ipex requires optimum-intel>="
f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
"`pip install --upgrade --upgrade-strategy eager "
"`optimum[ipex]`."
)
raise ImportError(msg)
from optimum.intel import IPEXSentenceTransformer # type: ignore[import]
@ -104,21 +108,21 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
def _embed(
self, texts: list[str], encode_kwargs: dict[str, Any]
) -> list[list[float]]:
"""
Embed a text using the HuggingFace transformer model.
"""Embed a text using the HuggingFace transformer model.
Args:
texts: The list of texts to embed.
encode_kwargs: Keyword arguments to pass when calling the
`encode` method for the documents of the SentenceTransformer
encode method.
encode method.
Returns:
List of embeddings, one for each text.
"""
import sentence_transformers # type: ignore[import]
texts = list(map(lambda x: x.replace("\n", " "), texts))
texts = [x.replace("\n", " ") for x in texts]
if self.multi_process:
pool = self._client.start_multi_process_pool()
embeddings = self._client.encode_multi_process(texts, pool)
@ -127,16 +131,17 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
embeddings = self._client.encode(
texts,
show_progress_bar=self.show_progress,
**encode_kwargs, # type: ignore
**encode_kwargs,
)
if isinstance(embeddings, list):
raise TypeError(
msg = (
"Expected embeddings to be a Tensor or a numpy array, "
"got a list instead."
)
raise TypeError(msg)
return embeddings.tolist()
return embeddings.tolist() # type: ignore[return-type]
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
@ -146,6 +151,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
return self._embed(texts, self.encode_kwargs)
@ -157,6 +163,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
embed_kwargs = (
self.query_encode_kwargs
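
The list(map(lambda ...)) call in _embed was rewritten as a list comprehension, most likely the C417 (unnecessary-map) fix from the flake8-comprehensions group selected above. A standalone sketch of the same transformation:

texts = ["hello\nworld", "foo"]
# before: texts = list(map(lambda x: x.replace("\n", " "), texts))
texts = [x.replace("\n", " ") for x in texts]  # same result, no lambda indirection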

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os
from typing import Any, Optional
@ -27,6 +29,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
task="feature-extraction",
huggingfacehub_api_token="my-api-key",
)
"""
client: Any = None #: :meta private:
@ -35,7 +38,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
"""Model name to use."""
provider: Optional[str] = None
"""Name of the provider to use for inference with the model specified in
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
repo_id: Optional[str] = None
"""Huggingfacehub repository id, for backward compatibility."""
@ -87,18 +90,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
)
if self.task not in VALID_TASKS:
raise ValueError(
msg = (
f"Got invalid task {self.task}, "
f"currently only {VALID_TASKS} are supported"
)
raise ValueError(msg)
self.client = client
self.async_client = async_client
except ImportError:
raise ImportError(
msg = (
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)
raise ImportError(msg)
return self
def embed_documents(self, texts: list[str]) -> list[list[float]]:
@ -109,6 +114,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
# replace newlines, which can negatively affect performance.
texts = [text.replace("\n", " ") for text in texts]
@ -125,6 +131,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
# replace newlines, which can negatively affect performance.
texts = [text.replace("\n", " ") for text in texts]
@ -142,9 +149,9 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
response = self.embed_documents([text])[0]
return response
return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> list[float]:
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
@ -154,6 +161,6 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
response = (await self.aembed_documents([text]))[0]
return response
return (await self.aembed_documents([text]))[0]
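
The embed_query/aembed_query changes drop a single-use local before the return, the RET504 (unnecessary-assignment-before-return) fix. A self-contained sketch with stand-in helpers:

def embed_documents(texts: list[str]) -> list[list[float]]:
    # stand-in for the real embedding call, for illustration only
    return [[0.0, 0.0, 0.0] for _ in texts]

def embed_query(text: str) -> list[float]:
    # before:
    #   response = embed_documents([text])[0]
    #   return response
    return embed_documents([text])[0]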

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import inspect
import logging
import os
@ -25,8 +27,7 @@ VALID_TASKS = (
class HuggingFaceEndpoint(LLM):
"""
Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
"""Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
To use this class, you should have installed the ``huggingface_hub`` package, and
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
@ -75,6 +76,7 @@ class HuggingFaceEndpoint(LLM):
huggingfacehub_api_token="my-api-key"
)
print(llm.invoke("What is Deep Learning?"))
""" # noqa: E501
endpoint_url: Optional[str] = None
@ -147,7 +149,8 @@ class HuggingFaceEndpoint(LLM):
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
msg = f"Found {field_name} supplied twice."
raise ValueError(msg)
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
@ -158,10 +161,11 @@ class HuggingFaceEndpoint(LLM):
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
msg = (
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
raise ValueError(msg)
values["model_kwargs"] = extra
@ -185,18 +189,20 @@ class HuggingFaceEndpoint(LLM):
repo_id = values.get("repo_id")
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
raise ValueError(
msg = (
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
"not more than one."
)
raise ValueError(msg)
values["model"] = (
model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
)
if not values["model"]:
raise ValueError(
msg = (
"Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
"model."
)
raise ValueError(msg)
return values
@model_validator(mode="after")
@ -274,12 +280,10 @@ class HuggingFaceEndpoint(LLM):
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{
"endpoint_url": self.endpoint_url,
"task": self.task,
"provider": self.provider,
},
**{"model_kwargs": _model_kwargs},
"endpoint_url": self.endpoint_url,
"task": self.task,
"provider": self.provider,
"model_kwargs": _model_kwargs,
}
@property
@ -310,19 +314,19 @@ class HuggingFaceEndpoint(LLM):
):
completion += chunk.text
return completion
else:
response_text = self.client.text_generation(
prompt=prompt,
model=self.model,
**invocation_params,
)
# Maybe the generation has stopped at one of the stop sequences:
# then we remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
response_text = self.client.text_generation(
prompt=prompt,
model=self.model,
**invocation_params,
)
# Maybe the generation has stopped at one of the stop sequences:
# then we remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
async def _acall(
self,
@ -339,20 +343,20 @@ class HuggingFaceEndpoint(LLM):
):
completion += chunk.text
return completion
else:
response_text = await self.async_client.text_generation(
prompt=prompt,
**invocation_params,
model=self.model,
stream=False,
)
# Maybe the generation has stopped at one of the stop sequences:
# then remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
response_text = await self.async_client.text_generation(
prompt=prompt,
**invocation_params,
model=self.model,
stream=False,
)
# Maybe the generation has stopped at one of the stop sequences:
# then remove this stop sequence from the end of the generated text
for stop_seq in invocation_params["stop"]:
if response_text[-len(stop_seq) :] == stop_seq:
response_text = response_text[: -len(stop_seq)]
return response_text
def _stream(
self,

View File

@ -10,7 +10,7 @@ from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from pydantic import ConfigDict, model_validator
from ..utils.import_utils import (
from langchain_huggingface.utils.import_utils import (
IMPORT_ERROR,
is_ipex_available,
is_openvino_available,
@ -87,7 +87,7 @@ class HuggingFacePipeline(BaseLLM):
def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Ensure model_id is set either by pipeline or user input."""
if "model_id" not in values:
if "pipeline" in values and values["pipeline"]:
if values.get("pipeline"):
values["model_id"] = values["pipeline"].model.name_or_path
else:
values["model_id"] = DEFAULT_MODEL_ID
@ -116,34 +116,38 @@ class HuggingFacePipeline(BaseLLM):
from transformers import pipeline as hf_pipeline # type: ignore[import]
except ImportError:
raise ValueError(
msg = (
"Could not import transformers python package. "
"Please install it with `pip install transformers`."
)
raise ValueError(msg)
_model_kwargs = model_kwargs.copy() if model_kwargs else {}
if device_map is not None:
if device is not None:
raise ValueError(
msg = (
"Both `device` and `device_map` are specified. "
"`device` will override `device_map`. "
"You will most likely encounter unexpected behavior."
"Please remove `device` and keep "
"`device_map`."
)
raise ValueError(msg)
if "device_map" in _model_kwargs:
raise ValueError("`device_map` is already specified in `model_kwargs`.")
msg = "`device_map` is already specified in `model_kwargs`."
raise ValueError(msg)
_model_kwargs["device_map"] = device_map
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
if backend in {"openvino", "ipex"}:
if task not in VALID_TASKS:
raise ValueError(
msg = (
f"Got invalid task {task}, "
f"currently only {VALID_TASKS} are supported"
)
raise ValueError(msg)
err_msg = f'Backend: {backend} {IMPORT_ERROR.format(f"optimum[{backend}]")}'
if not is_optimum_intel_available():
@ -156,12 +160,13 @@ class HuggingFacePipeline(BaseLLM):
else _MIN_OPTIMUM_VERSION
)
if is_optimum_intel_version("<", min_optimum_version):
raise ImportError(
msg = (
f"Backend: {backend} requires optimum-intel>="
f"{min_optimum_version}. You can install it with pip: "
"`pip install --upgrade --upgrade-strategy eager "
f"`optimum[{backend}]`."
)
raise ImportError(msg)
if backend == "openvino":
if not is_openvino_available():
@ -240,10 +245,11 @@ class HuggingFacePipeline(BaseLLM):
cuda_device_count = torch.cuda.device_count()
if device < -1 or (device >= cuda_device_count):
raise ValueError(
msg = (
f"Got device=={device}, "
f"device is required to be within [-1, {cuda_device_count})"
)
raise ValueError(msg)
if device_map is not None and device < 0:
device = None
if device is not None and device < 0 and cuda_device_count > 0:
@ -271,10 +277,11 @@ class HuggingFacePipeline(BaseLLM):
**_pipeline_kwargs,
)
if pipeline.task not in VALID_TASKS:
raise ValueError(
msg = (
f"Got invalid task {pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
raise ValueError(msg)
return cls(
pipeline=pipeline,
model_id=model_id,
@ -324,21 +331,22 @@ class HuggingFacePipeline(BaseLLM):
# if model returns multiple generations, pick the top one
response = response[0]
if self.pipeline.task == "text-generation":
text = response["generated_text"]
elif self.pipeline.task == "text2text-generation":
text = response["generated_text"]
elif self.pipeline.task == "image-text-to-text":
if (
self.pipeline.task == "text-generation"
or self.pipeline.task == "text2text-generation"
or self.pipeline.task == "image-text-to-text"
):
text = response["generated_text"]
elif self.pipeline.task == "summarization":
text = response["summary_text"]
elif self.pipeline.task in "translation":
text = response["translation_text"]
else:
raise ValueError(
msg = (
f"Got invalid task {self.pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
raise ValueError(msg)
if skip_prompt:
text = text[len(batch_prompts[j]) :]
# Append the processed text to results
@ -378,10 +386,7 @@ class HuggingFacePipeline(BaseLLM):
scores: torch.FloatTensor,
**kwargs: Any,
) -> bool:
for stop_id in stopping_ids_list:
if input_ids[0][-1] == stop_id:
return True
return False
return any(input_ids[0][-1] == stop_id for stop_id in stopping_ids_list)
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
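
The loop-and-flag body of StopOnTokens above is collapsed into any(), the SIM110 (reimplemented-builtin) rewrite. A minimal sketch with illustrative names:

def should_stop(last_token_id: int, stopping_ids_list: list[int]) -> bool:
    # before:
    #   for stop_id in stopping_ids_list:
    #       if last_token_id == stop_id:
    #           return True
    #   return False
    return any(last_token_id == stop_id for stop_id in stopping_ids_list)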

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import importlib.metadata
import importlib.util
import operator as op
@ -46,8 +48,7 @@ def compare_versions(
operation: str,
requirement_version: str,
) -> bool:
"""
Compare a library version to some requirement using a given operation.
"""Compare a library version to some requirement using a given operation.
Arguments:
library_or_version (`str` or `packaging.version.Version`):
@ -56,12 +57,14 @@ def compare_versions(
A string representation of an operator, such as `">"` or `"<="`.
requirement_version (`str`):
The version to compare the library version against
"""
if operation not in STR_OPERATION_TO_FUNC.keys():
raise ValueError(
if operation not in STR_OPERATION_TO_FUNC:
msg = (
f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}"
f", received {operation}"
)
raise ValueError(msg)
if isinstance(library_or_version, str):
library_or_version = version.parse(
importlib.metadata.version(library_or_version)
@ -88,9 +91,7 @@ def is_openvino_available() -> bool:
def is_optimum_version(operation: str, reference_version: str) -> bool:
"""
Compare the current Optimum version to a given reference with an operation.
"""
"""Compare the current Optimum version to a given reference with an operation."""
if not _optimum_version:
return False
return compare_versions(
@ -99,9 +100,7 @@ def is_optimum_version(operation: str, reference_version: str) -> bool:
def is_optimum_intel_version(operation: str, reference_version: str) -> bool:
"""
Compare the current Optimum Intel version to a given reference with an operation.
"""
"""Compare current Optimum Intel version to a given reference with an operation."""
if not _optimum_intel_version:
return False
return compare_versions(

View File

@ -58,8 +58,60 @@ disallow_untyped_defs = "True"
target-version = "py39"
[tool.ruff.lint]
select = ["E", "F", "I", "T201", "UP", "S"]
ignore = ["UP007"]
select = [
"A", # flake8-builtins
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
"DOC", # pydoclint
"E", # pycodestyle error
"EM", # flake8-errmsg
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT", # flake8-boolean-trap
"FLY", # flake8-flynt
"I", # isort
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
"PGH", # pygrep-hooks
"PIE", # flake8-pie
"PERF", # flake8-perf
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
"SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D101", # pydocstyle: Missing docstring in public class
"D102", # pydocstyle: Missing docstring in public method
"D103", # pydocstyle: Missing docstring in public function
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"D107", # pydocstyle: Missing docstring in __init__
"D407", # pydocstyle: Missing-dashed-underline-after-section
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
]
[tool.coverage.run]
omit = ["tests/*"]
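
UP007 stays in the ignore list here (and pyupgrade.keep-runtime-typing is enabled in the first pyproject above), so with target-version = "py39" the typing.Optional/Union annotations are deliberately not rewritten to PEP 604 unions; presumably because `X | Y` in annotations evaluated at runtime (e.g. by Pydantic models) requires Python 3.10. A hypothetical example of what the config leaves untouched:

from typing import Optional

def lookup(mapping: dict[str, str], key: str, default: Optional[str] = None) -> Optional[str]:
    # UP007 would normally suggest `str | None`, but that spelling is only
    # valid at runtime on Python 3.10+, so Optional[str] is kept here.
    return mapping.get(key, default)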

View File

@ -4,4 +4,3 @@ import pytest # type: ignore[import-not-found, import-not-found]
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@ -1,4 +1,4 @@
"""Standard LangChain interface tests"""
"""Standard LangChain interface tests."""
import pytest
from langchain_core.language_models import BaseChatModel

View File

@ -31,8 +31,7 @@ def mock_llm() -> Mock:
"langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
)
def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
chat_hf = ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())
return chat_hf
return ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())
def test_create_chat_result(chat_hugging_face: Any) -> None:

View File

@ -6,8 +6,7 @@ DEFAULT_MODEL_ID = "gpt2"
def test_initialization_default() -> None:
"""Test default initialization"""
"""Test default initialization."""
llm = HuggingFacePipeline()
assert llm.model_id == DEFAULT_MODEL_ID
@ -15,8 +14,7 @@ def test_initialization_default() -> None:
@patch("transformers.pipeline")
def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
"""Test initialization with a pipeline object"""
"""Test initialization with a pipeline object."""
mock_pipe = MagicMock()
mock_pipe.model.name_or_path = "mock-model-id"
mock_pipeline.return_value = mock_pipe
@ -32,8 +30,7 @@ def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
def test_initialization_with_from_model_id(
mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
) -> None:
"""Test initialization with the from_model_id method"""
"""Test initialization with the from_model_id method."""
mock_tokenizer.return_value = MagicMock(pad_token_id=0)
mock_model.return_value = MagicMock()