Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-12 15:59:56 +00:00)
huggingface[patch]: ruff fixes and rules (#31912)

* bump ruff deps
* add more thorough ruff rules
* fix said rules

This commit is contained in:
parent 06ab2972e3
commit 750721b4c3
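Most of the diff below is a mechanical application of the newly enabled rules. As a rough illustration of the most frequent rewrite, the flake8-errmsg (EM) rules flag string literals and f-strings passed directly to an exception constructor, so a one-line raise becomes an assign-then-raise. The sketch below is illustrative only; validate_role is a made-up function and is not part of this patch.

def validate_role(role: str) -> str:
    """Return the role if it is supported, otherwise raise."""
    supported = ("user", "assistant", "system", "tool")
    if role not in supported:
        # EM102-style fix: build the message in a variable first, then raise,
        # instead of raising ValueError(f"Got unknown role {role!r}") directly.
        msg = f"Got unknown role {role!r}, expected one of {supported}"
        raise ValueError(msg)
    return role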
@@ -87,7 +87,16 @@ select = [
     "W",    # pycodestyle warning
     "YTT",  # flake8-2020
 ]
-ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D107"]
+ignore = [
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
+]
 pyupgrade.keep-runtime-typing = true

 [tool.mypy]
@@ -180,13 +180,14 @@ select = [
     "YTT",  # flake8-2020
 ]
 ignore = [
-    "D100",  # Missing docstring in public module
-    "D101",  # Missing docstring in public class
-    "D102",  # Missing docstring in public method
-    "D103",  # Missing docstring in public function
-    "D104",  # Missing docstring in public package
-    "D105",  # Missing docstring in magic method
-    "D107",  # Missing docstring in __init__
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
     "COM812",  # Messes with the formatter
     "ISC001",  # Messes with the formatter
     "PERF203",  # Rarely useful
@@ -97,13 +97,14 @@ select = [
     "YTT",  # flake8-2020
 ]
 ignore = [
-    "D100",  # Missing docstring in public module
-    "D101",  # Missing docstring in public class
-    "D102",  # Missing docstring in public method
-    "D103",  # Missing docstring in public function
-    "D104",  # Missing docstring in public package
-    "D105",  # Missing docstring in magic method
-    "D107",  # Missing docstring in __init__
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
     "D214",  # Section over-indented, doesn't play well with reStructuredText
     "COM812",  # Messes with the formatter
     "ISC001",  # Messes with the formatter
@@ -343,7 +343,6 @@ def function() -> Callable:
     """Dummy function.

     Args:
-    ----
         arg1: foo
         arg2: one of 'bar', 'baz'

@@ -1,5 +1,8 @@
 """Hugging Face Chat Wrapper."""

+from __future__ import annotations
+
+import contextlib
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from dataclasses import dataclass
@@ -58,8 +61,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self

-from ..llms.huggingface_endpoint import HuggingFaceEndpoint
-from ..llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline


 @dataclass
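Besides the import cleanup just above (relative ..llms imports becoming absolute langchain_huggingface.llms imports), the chat model hunks that follow repeatedly flatten elif/else branches that sit after a return into plain if blocks, in line with the newly enabled flake8-return (RET) rules. A minimal sketch of that shape follows; describe_role is invented for illustration and is not part of the patch.

def describe_role(role: str) -> str:
    """Map a chat role onto a short description."""
    if role == "user":
        return "human input"
    if role == "assistant":  # was an elif before the RET505-style cleanup
        return "model output"
    if role == "system":
        return "instructions"
    return "other"  # previously sat inside an else branch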
@ -111,6 +114,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
|
||||
"""
|
||||
message_dict: dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
@ -153,7 +157,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"tool_call_id": message.tool_call_id,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
msg = f"Got unknown type {message}"
|
||||
raise TypeError(msg)
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
@ -167,11 +172,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
|
||||
Returns:
|
||||
The LangChain message.
|
||||
|
||||
"""
|
||||
role = _dict.get("role")
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict.get("content", ""))
|
||||
elif role == "assistant":
|
||||
if role == "assistant":
|
||||
content = _dict.get("content", "") or ""
|
||||
additional_kwargs: dict = {}
|
||||
if function_call := _dict.get("function_call"):
|
||||
@ -193,13 +199,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
elif role == "system":
|
||||
if role == "system":
|
||||
return SystemMessage(content=_dict.get("content", ""))
|
||||
elif role == "function":
|
||||
if role == "function":
|
||||
return FunctionMessage(
|
||||
content=_dict.get("content", ""), name=_dict.get("name", "")
|
||||
)
|
||||
elif role == "tool":
|
||||
if role == "tool":
|
||||
additional_kwargs = {}
|
||||
if "name" in _dict:
|
||||
additional_kwargs["name"] = _dict["name"]
|
||||
@ -208,8 +214,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
tool_call_id=_dict.get("tool_call_id", ""),
|
||||
additional_kwargs=additional_kwargs,
|
||||
)
|
||||
else:
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role or "")
|
||||
return ChatMessage(content=_dict.get("content", ""), role=role or "")
|
||||
|
||||
|
||||
def _is_huggingface_hub(llm: Any) -> bool:
|
||||
@ -241,7 +246,7 @@ def _convert_chunk_to_message_chunk(
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for rtc in raw_tool_calls:
|
||||
try:
|
||||
with contextlib.suppress(KeyError):
|
||||
tool_call_chunks.append(
|
||||
create_tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
@ -250,11 +255,9 @@ def _convert_chunk_to_message_chunk(
|
||||
index=rtc.get("index"),
|
||||
)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
if usage := chunk.get("usage"):
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
@ -271,16 +274,15 @@ def _convert_chunk_to_message_chunk(
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
if role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
if role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content) # type: ignore
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _is_huggingface_textgen_inference(llm: Any) -> bool:
|
||||
@ -304,7 +306,7 @@ def _is_huggingface_pipeline(llm: Any) -> bool:
|
||||
|
||||
|
||||
class ChatHuggingFace(BaseChatModel):
|
||||
"""Hugging Face LLM's as ChatModels.
|
||||
r"""Hugging Face LLM's as ChatModels.
|
||||
|
||||
Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
|
||||
`HuggingFaceHub`, and `HuggingFacePipeline` LLMs.
|
||||
@ -484,7 +486,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
"""Whether to return logprobs."""
|
||||
top_logprobs: Optional[int] = None
|
||||
"""Number of most likely tokens to return at each token position, each with
|
||||
an associated log probability. `logprobs` must be set to true
|
||||
an associated log probability. `logprobs` must be set to true
|
||||
if this parameter is used."""
|
||||
logit_bias: Optional[dict[int, int]] = None
|
||||
"""Modify the likelihood of specified tokens appearing in the completion."""
|
||||
@ -511,11 +513,12 @@ class ChatHuggingFace(BaseChatModel):
|
||||
and not _is_huggingface_endpoint(self.llm)
|
||||
and not _is_huggingface_pipeline(self.llm)
|
||||
):
|
||||
raise TypeError(
|
||||
msg = (
|
||||
"Expected llm to be one of HuggingFaceTextGenInference, "
|
||||
"HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
|
||||
f"received {type(self.llm)}"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
return self
|
||||
|
||||
def _create_chat_result(self, response: dict) -> ChatResult:
|
||||
@ -529,7 +532,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
"output_tokens": token_usage.get("completion_tokens", 0),
|
||||
"total_tokens": token_usage.get("total_tokens", 0),
|
||||
}
|
||||
generation_info = dict(finish_reason=res.get("finish_reason"))
|
||||
generation_info = {"finish_reason": res.get("finish_reason")}
|
||||
if "logprobs" in res:
|
||||
generation_info["logprobs"] = res["logprobs"]
|
||||
gen = ChatGeneration(
|
||||
@ -558,7 +561,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
answer = self.llm.client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
elif _is_huggingface_endpoint(self.llm):
|
||||
if _is_huggingface_endpoint(self.llm):
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
@ -573,18 +576,17 @@ class ChatHuggingFace(BaseChatModel):
|
||||
}
|
||||
answer = self.llm.client.chat_completion(messages=message_dicts, **params)
|
||||
return self._create_chat_result(answer)
|
||||
else:
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
|
||||
if should_stream:
|
||||
stream_iter = self.llm._stream(
|
||||
llm_input, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
llm_result = self.llm._generate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
if should_stream:
|
||||
stream_iter = self.llm._stream(
|
||||
llm_input, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
return generate_from_stream(stream_iter)
|
||||
llm_result = self.llm._generate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
@ -598,7 +600,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
|
||||
return self._create_chat_result(answer)
|
||||
elif _is_huggingface_endpoint(self.llm):
|
||||
if _is_huggingface_endpoint(self.llm):
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
@ -616,17 +618,14 @@ class ChatHuggingFace(BaseChatModel):
|
||||
messages=message_dicts, **params
|
||||
)
|
||||
return self._create_chat_result(answer)
|
||||
|
||||
elif _is_huggingface_pipeline(self.llm):
|
||||
raise NotImplementedError(
|
||||
"async generation is not supported with HuggingFacePipeline"
|
||||
)
|
||||
else:
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = await self.llm._agenerate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
if _is_huggingface_pipeline(self.llm):
|
||||
msg = "async generation is not supported with HuggingFacePipeline"
|
||||
raise NotImplementedError(msg)
|
||||
llm_input = self._to_chat_prompt(messages)
|
||||
llm_result = await self.llm._agenerate(
|
||||
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return self._to_chat_result(llm_result)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
@ -721,10 +720,12 @@ class ChatHuggingFace(BaseChatModel):
|
||||
) -> str:
|
||||
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
|
||||
if not messages:
|
||||
raise ValueError("At least one HumanMessage must be provided!")
|
||||
msg = "At least one HumanMessage must be provided!"
|
||||
raise ValueError(msg)
|
||||
|
||||
if not isinstance(messages[-1], HumanMessage):
|
||||
raise ValueError("Last message must be a HumanMessage!")
|
||||
msg = "Last message must be a HumanMessage!"
|
||||
raise ValueError(msg)
|
||||
|
||||
messages_dicts = [self._to_chatml_format(m) for m in messages]
|
||||
|
||||
@ -734,7 +735,6 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def _to_chatml_format(self, message: BaseMessage) -> dict:
|
||||
"""Convert LangChain message to ChatML format."""
|
||||
|
||||
if isinstance(message, SystemMessage):
|
||||
role = "system"
|
||||
elif isinstance(message, AIMessage):
|
||||
@ -742,7 +742,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
elif isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {type(message)}")
|
||||
msg = f"Unknown message type: {type(message)}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return {"role": role, "content": message.content}
|
||||
|
||||
@ -761,8 +762,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
)
|
||||
|
||||
def _resolve_model_id(self) -> None:
|
||||
"""Resolve the model_id from the LLM's inference_server_url"""
|
||||
|
||||
"""Resolve the model_id from the LLM's inference_server_url."""
|
||||
from huggingface_hub import list_inference_endpoints # type: ignore[import]
|
||||
|
||||
if _is_huggingface_hub(self.llm) or (
|
||||
@ -770,9 +770,9 @@ class ChatHuggingFace(BaseChatModel):
|
||||
):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
elif _is_huggingface_textgen_inference(self.llm):
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
elif _is_huggingface_pipeline(self.llm):
|
||||
if _is_huggingface_pipeline(self.llm):
|
||||
from transformers import AutoTokenizer # type: ignore[import]
|
||||
|
||||
self.model_id = self.model_id or self.llm.model_id
|
||||
@ -782,29 +782,29 @@ class ChatHuggingFace(BaseChatModel):
|
||||
else self.tokenizer
|
||||
)
|
||||
return
|
||||
elif _is_huggingface_endpoint(self.llm):
|
||||
if _is_huggingface_endpoint(self.llm):
|
||||
self.model_id = self.llm.repo_id or self.llm.model
|
||||
return
|
||||
else:
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
endpoint_url = self.llm.endpoint_url
|
||||
available_endpoints = list_inference_endpoints("*")
|
||||
for endpoint in available_endpoints:
|
||||
if endpoint.url == endpoint_url:
|
||||
self.model_id = endpoint.repository
|
||||
|
||||
if not self.model_id:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Failed to resolve model_id:"
|
||||
f"Could not find model id for inference server: {endpoint_url}"
|
||||
"Make sure that your Hugging Face token has access to the endpoint."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required"], bool]
|
||||
Union[dict, str, Literal["auto", "none", "required"], bool] # noqa: PYI051
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
@ -823,15 +823,16 @@ class ChatHuggingFace(BaseChatModel):
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
|
||||
"""
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice is not None and tool_choice:
|
||||
if len(formatted_tools) != 1:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"When specifying `tool_choice`, you must provide exactly one "
|
||||
f"tool. Received {len(formatted_tools)} tools."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if isinstance(tool_choice, str):
|
||||
if tool_choice not in ("auto", "none", "required"):
|
||||
tool_choice = {
|
||||
@ -845,15 +846,17 @@ class ChatHuggingFace(BaseChatModel):
|
||||
formatted_tools[0]["function"]["name"]
|
||||
!= tool_choice["function"]["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tool was {formatted_tools[0]['function']['name']}."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Unrecognized tool_choice type. Expected str, bool or dict. "
|
||||
f"Received: {tool_choice}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
@ -892,6 +895,11 @@ class ChatHuggingFace(BaseChatModel):
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
kwargs:
|
||||
Additional parameters to pass to the underlying LLM's
|
||||
:meth:`langchain_core.language_models.chat.BaseChatModel.bind`
|
||||
method, such as `response_format` or `ls_structured_output_format`.
|
||||
|
||||
Returns:
|
||||
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||
|
||||
@ -908,14 +916,16 @@ class ChatHuggingFace(BaseChatModel):
|
||||
""" # noqa: E501
|
||||
_ = kwargs.pop("strict", None)
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
msg = f"Received unsupported arguments {kwargs}"
|
||||
raise ValueError(msg)
|
||||
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
formatted_tool = convert_to_openai_tool(schema)
|
||||
tool_name = formatted_tool["function"]["name"]
|
||||
llm = self.bind_tools(
|
||||
@ -927,19 +937,18 @@ class ChatHuggingFace(BaseChatModel):
|
||||
},
|
||||
)
|
||||
if is_pydantic_schema:
|
||||
raise NotImplementedError(
|
||||
"Pydantic schema is not supported for function calling"
|
||||
)
|
||||
else:
|
||||
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
|
||||
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
|
||||
)
|
||||
msg = "Pydantic schema is not supported for function calling"
|
||||
raise NotImplementedError(msg)
|
||||
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
|
||||
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
|
||||
)
|
||||
elif method == "json_schema":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"schema must be specified when method is 'json_schema'. "
|
||||
"Received None."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
formatted_schema = convert_to_json_schema(schema)
|
||||
llm = self.bind(
|
||||
response_format={"type": "json_object", "schema": formatted_schema},
|
||||
@ -963,10 +972,11 @@ class ChatHuggingFace(BaseChatModel):
|
||||
JsonOutputKeyToolsParser, JsonOutputParser
|
||||
] = JsonOutputParser() # type: ignore[arg-type]
|
||||
else:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
f"'json_mode'. Received: '{method}'"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
@ -977,8 +987,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
return llm | output_parser
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
@ -991,8 +1000,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
@property
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling Hugging Face
|
||||
Inference Providers API."""
|
||||
"""Get default parameters for calling Hugging Face Inference Providers API."""
|
||||
params = {
|
||||
"model": self.model_id,
|
||||
"stream": self.streaming,
|
||||
|
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from ..utils.import_utils import (
|
||||
from langchain_huggingface.utils.import_utils import (
|
||||
IMPORT_ERROR,
|
||||
is_ipex_available,
|
||||
is_optimum_intel_available,
|
||||
@ -33,12 +35,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
model_kwargs=model_kwargs,
|
||||
encode_kwargs=encode_kwargs
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
"""Path to store models.
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the Sentence Transformer model, such as `device`,
|
||||
@ -46,12 +49,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
|
||||
encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the documents of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass when calling the `encode` method for the query of
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
|
||||
`precision`, `normalize_embeddings`, and more.
|
||||
See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
|
||||
multi_process: bool = False
|
||||
@ -65,24 +68,25 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
import sentence_transformers # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import sentence_transformers python package. "
|
||||
"Please install it with `pip install sentence-transformers`."
|
||||
) from exc
|
||||
)
|
||||
raise ImportError(msg) from exc
|
||||
|
||||
if self.model_kwargs.get("backend", "torch") == "ipex":
|
||||
if not is_optimum_intel_available() or not is_ipex_available():
|
||||
raise ImportError(
|
||||
f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
|
||||
)
|
||||
msg = f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
|
||||
raise ImportError(msg)
|
||||
|
||||
if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
|
||||
raise ImportError(
|
||||
msg = (
|
||||
f"Backend: ipex requires optimum-intel>="
|
||||
f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
|
||||
"`pip install --upgrade --upgrade-strategy eager "
|
||||
"`optimum[ipex]`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
from optimum.intel import IPEXSentenceTransformer # type: ignore[import]
|
||||
|
||||
@ -104,21 +108,21 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
def _embed(
|
||||
self, texts: list[str], encode_kwargs: dict[str, Any]
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Embed a text using the HuggingFace transformer model.
|
||||
"""Embed a text using the HuggingFace transformer model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
encode_kwargs: Keyword arguments to pass when calling the
|
||||
`encode` method for the documents of the SentenceTransformer
|
||||
encode method.
|
||||
encode method.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
import sentence_transformers # type: ignore[import]
|
||||
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
texts = [x.replace("\n", " ") for x in texts]
|
||||
if self.multi_process:
|
||||
pool = self._client.start_multi_process_pool()
|
||||
embeddings = self._client.encode_multi_process(texts, pool)
|
||||
@ -127,16 +131,17 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
embeddings = self._client.encode(
|
||||
texts,
|
||||
show_progress_bar=self.show_progress,
|
||||
**encode_kwargs, # type: ignore
|
||||
**encode_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(embeddings, list):
|
||||
raise TypeError(
|
||||
msg = (
|
||||
"Expected embeddings to be a Tensor or a numpy array, "
|
||||
"got a list instead."
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
return embeddings.tolist()
|
||||
return embeddings.tolist() # type: ignore[return-type]
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
||||
@ -146,6 +151,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
return self._embed(texts, self.encode_kwargs)
|
||||
|
||||
@ -157,6 +163,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
|
||||
"""
|
||||
embed_kwargs = (
|
||||
self.query_encode_kwargs
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -27,6 +29,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
task="feature-extraction",
|
||||
huggingfacehub_api_token="my-api-key",
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
@ -35,7 +38,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
"""Model name to use."""
|
||||
provider: Optional[str] = None
|
||||
"""Name of the provider to use for inference with the model specified in
|
||||
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
|
||||
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
|
||||
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
|
||||
repo_id: Optional[str] = None
|
||||
"""Huggingfacehub repository id, for backward compatibility."""
|
||||
@ -87,18 +90,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
|
||||
if self.task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Got invalid task {self.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self.client = client
|
||||
self.async_client = async_client
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"Could not import huggingface_hub python package. "
|
||||
"Please install it with `pip install huggingface_hub`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
return self
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
@ -109,6 +114,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
@ -125,6 +131,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
|
||||
"""
|
||||
# replace newlines, which can negatively affect performance.
|
||||
texts = [text.replace("\n", " ") for text in texts]
|
||||
@ -142,9 +149,9 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
|
||||
"""
|
||||
response = self.embed_documents([text])[0]
|
||||
return response
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
|
||||
@ -154,6 +161,6 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
|
||||
"""
|
||||
response = (await self.aembed_documents([text]))[0]
|
||||
return response
|
||||
return (await self.aembed_documents([text]))[0]
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@ -25,8 +27,7 @@ VALID_TASKS = (
|
||||
|
||||
|
||||
class HuggingFaceEndpoint(LLM):
|
||||
"""
|
||||
Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
|
||||
"""Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
|
||||
|
||||
To use this class, you should have installed the ``huggingface_hub`` package, and
|
||||
the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,
|
||||
@ -75,6 +76,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
huggingfacehub_api_token="my-api-key"
|
||||
)
|
||||
print(llm.invoke("What is Deep Learning?"))
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
@ -147,7 +149,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
msg = f"Found {field_name} supplied twice."
|
||||
raise ValueError(msg)
|
||||
if field_name not in all_required_field_names:
|
||||
logger.warning(
|
||||
f"""WARNING! {field_name} is not default parameter.
|
||||
@ -158,10 +161,11 @@ class HuggingFaceEndpoint(LLM):
|
||||
|
||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||
if invalid_model_kwargs:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
values["model_kwargs"] = extra
|
||||
|
||||
@ -185,18 +189,20 @@ class HuggingFaceEndpoint(LLM):
|
||||
repo_id = values.get("repo_id")
|
||||
|
||||
if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
|
||||
"not more than one."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
values["model"] = (
|
||||
model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
|
||||
)
|
||||
if not values["model"]:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
|
||||
"model."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
@ -274,12 +280,10 @@ class HuggingFaceEndpoint(LLM):
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{
|
||||
"endpoint_url": self.endpoint_url,
|
||||
"task": self.task,
|
||||
"provider": self.provider,
|
||||
},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
"endpoint_url": self.endpoint_url,
|
||||
"task": self.task,
|
||||
"provider": self.provider,
|
||||
"model_kwargs": _model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
@ -310,19 +314,19 @@ class HuggingFaceEndpoint(LLM):
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
response_text = self.client.text_generation(
|
||||
prompt=prompt,
|
||||
model=self.model,
|
||||
**invocation_params,
|
||||
)
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then we remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
response_text = self.client.text_generation(
|
||||
prompt=prompt,
|
||||
model=self.model,
|
||||
**invocation_params,
|
||||
)
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then we remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
@ -339,20 +343,20 @@ class HuggingFaceEndpoint(LLM):
|
||||
):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
response_text = await self.async_client.text_generation(
|
||||
prompt=prompt,
|
||||
**invocation_params,
|
||||
model=self.model,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
response_text = await self.async_client.text_generation(
|
||||
prompt=prompt,
|
||||
**invocation_params,
|
||||
model=self.model,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Maybe the generation has stopped at one of the stop sequences:
|
||||
# then remove this stop sequence from the end of the generated text
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if response_text[-len(stop_seq) :] == stop_seq:
|
||||
response_text = response_text[: -len(stop_seq)]
|
||||
return response_text
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
|
@ -10,7 +10,7 @@ from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from pydantic import ConfigDict, model_validator
|
||||
|
||||
from ..utils.import_utils import (
|
||||
from langchain_huggingface.utils.import_utils import (
|
||||
IMPORT_ERROR,
|
||||
is_ipex_available,
|
||||
is_openvino_available,
|
||||
@ -87,7 +87,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Ensure model_id is set either by pipeline or user input."""
|
||||
if "model_id" not in values:
|
||||
if "pipeline" in values and values["pipeline"]:
|
||||
if values.get("pipeline"):
|
||||
values["model_id"] = values["pipeline"].model.name_or_path
|
||||
else:
|
||||
values["model_id"] = DEFAULT_MODEL_ID
|
||||
@ -116,34 +116,38 @@ class HuggingFacePipeline(BaseLLM):
|
||||
from transformers import pipeline as hf_pipeline # type: ignore[import]
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Could not import transformers python package. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
_model_kwargs = model_kwargs.copy() if model_kwargs else {}
|
||||
if device_map is not None:
|
||||
if device is not None:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Both `device` and `device_map` are specified. "
|
||||
"`device` will override `device_map`. "
|
||||
"You will most likely encounter unexpected behavior."
|
||||
"Please remove `device` and keep "
|
||||
"`device_map`."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if "device_map" in _model_kwargs:
|
||||
raise ValueError("`device_map` is already specified in `model_kwargs`.")
|
||||
msg = "`device_map` is already specified in `model_kwargs`."
|
||||
raise ValueError(msg)
|
||||
|
||||
_model_kwargs["device_map"] = device_map
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
|
||||
|
||||
if backend in {"openvino", "ipex"}:
|
||||
if task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Got invalid task {task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
err_msg = f'Backend: {backend} {IMPORT_ERROR.format(f"optimum[{backend}]")}'
|
||||
if not is_optimum_intel_available():
|
||||
@ -156,12 +160,13 @@ class HuggingFacePipeline(BaseLLM):
|
||||
else _MIN_OPTIMUM_VERSION
|
||||
)
|
||||
if is_optimum_intel_version("<", min_optimum_version):
|
||||
raise ImportError(
|
||||
msg = (
|
||||
f"Backend: {backend} requires optimum-intel>="
|
||||
f"{min_optimum_version}. You can install it with pip: "
|
||||
"`pip install --upgrade --upgrade-strategy eager "
|
||||
f"`optimum[{backend}]`."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
if backend == "openvino":
|
||||
if not is_openvino_available():
|
||||
@ -240,10 +245,11 @@ class HuggingFacePipeline(BaseLLM):
|
||||
|
||||
cuda_device_count = torch.cuda.device_count()
|
||||
if device < -1 or (device >= cuda_device_count):
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Got device=={device}, "
|
||||
f"device is required to be within [-1, {cuda_device_count})"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if device_map is not None and device < 0:
|
||||
device = None
|
||||
if device is not None and device < 0 and cuda_device_count > 0:
|
||||
@ -271,10 +277,11 @@ class HuggingFacePipeline(BaseLLM):
|
||||
**_pipeline_kwargs,
|
||||
)
|
||||
if pipeline.task not in VALID_TASKS:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Got invalid task {pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return cls(
|
||||
pipeline=pipeline,
|
||||
model_id=model_id,
|
||||
@ -324,21 +331,22 @@ class HuggingFacePipeline(BaseLLM):
|
||||
# if model returns multiple generations, pick the top one
|
||||
response = response[0]
|
||||
|
||||
if self.pipeline.task == "text-generation":
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "text2text-generation":
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "image-text-to-text":
|
||||
if (
|
||||
self.pipeline.task == "text-generation"
|
||||
or self.pipeline.task == "text2text-generation"
|
||||
or self.pipeline.task == "image-text-to-text"
|
||||
):
|
||||
text = response["generated_text"]
|
||||
elif self.pipeline.task == "summarization":
|
||||
text = response["summary_text"]
|
||||
elif self.pipeline.task in "translation":
|
||||
text = response["translation_text"]
|
||||
else:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Got invalid task {self.pipeline.task}, "
|
||||
f"currently only {VALID_TASKS} are supported"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if skip_prompt:
|
||||
text = text[len(batch_prompts[j]) :]
|
||||
# Append the processed text to results
|
||||
@ -378,10 +386,7 @@ class HuggingFacePipeline(BaseLLM):
|
||||
scores: torch.FloatTensor,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
for stop_id in stopping_ids_list:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
return False
|
||||
return any(input_ids[0][-1] == stop_id for stop_id in stopping_ids_list)
|
||||
|
||||
stopping_criteria = StoppingCriteriaList([StopOnTokens()])
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import operator as op
|
||||
@ -46,8 +48,7 @@ def compare_versions(
|
||||
operation: str,
|
||||
requirement_version: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Compare a library version to some requirement using a given operation.
|
||||
"""Compare a library version to some requirement using a given operation.
|
||||
|
||||
Arguments:
|
||||
library_or_version (`str` or `packaging.version.Version`):
|
||||
@ -56,12 +57,14 @@ def compare_versions(
|
||||
A string representation of an operator, such as `">"` or `"<="`.
|
||||
requirement_version (`str`):
|
||||
The version to compare the library version against
|
||||
|
||||
"""
|
||||
if operation not in STR_OPERATION_TO_FUNC.keys():
|
||||
raise ValueError(
|
||||
if operation not in STR_OPERATION_TO_FUNC:
|
||||
msg = (
|
||||
f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}"
|
||||
f", received {operation}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if isinstance(library_or_version, str):
|
||||
library_or_version = version.parse(
|
||||
importlib.metadata.version(library_or_version)
|
||||
@ -88,9 +91,7 @@ def is_openvino_available() -> bool:
|
||||
|
||||
|
||||
def is_optimum_version(operation: str, reference_version: str) -> bool:
|
||||
"""
|
||||
Compare the current Optimum version to a given reference with an operation.
|
||||
"""
|
||||
"""Compare the current Optimum version to a given reference with an operation."""
|
||||
if not _optimum_version:
|
||||
return False
|
||||
return compare_versions(
|
||||
@ -99,9 +100,7 @@ def is_optimum_version(operation: str, reference_version: str) -> bool:
|
||||
|
||||
|
||||
def is_optimum_intel_version(operation: str, reference_version: str) -> bool:
|
||||
"""
|
||||
Compare the current Optimum Intel version to a given reference with an operation.
|
||||
"""
|
||||
"""Compare current Optimum Intel version to a given reference with an operation."""
|
||||
if not _optimum_intel_version:
|
||||
return False
|
||||
return compare_versions(
|
||||
|
@@ -58,8 +58,60 @@ disallow_untyped_defs = "True"
 target-version = "py39"

 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = ["UP007"]
+select = [
+    "A",      # flake8-builtins
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flake8-flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # isort-comprehensions
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # flake8-perf
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-rst-docstrings
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",    # pydocstyle: Missing docstring in public module
+    "D101",    # pydocstyle: Missing docstring in public class
+    "D102",    # pydocstyle: Missing docstring in public method
+    "D103",    # pydocstyle: Missing docstring in public function
+    "D104",    # pydocstyle: Missing docstring in public package
+    "D105",    # pydocstyle: Missing docstring in magic method
+    "D107",    # pydocstyle: Missing docstring in __init__
+    "D407",    # pydocstyle: Missing-dashed-underline-after-section
+    "COM812",  # Messes with the formatter
+    "ISC001",  # Messes with the formatter
+    "PERF203", # Rarely useful
+    "S112",    # Rarely useful
+    "RUF012",  # Doesn't play well with Pydantic
+    "SLF001",  # Private member access
+    "UP007",   # pyupgrade: non-pep604-annotation-union
+
+]

 [tool.coverage.run]
 omit = ["tests/*"]
@@ -4,4 +4,3 @@ import pytest  # type: ignore[import-not-found, import-not-found]
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -31,8 +31,7 @@ def mock_llm() -> Mock:
     "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
 )
 def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
-    chat_hf = ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())
-    return chat_hf
+    return ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())


 def test_create_chat_result(chat_hugging_face: Any) -> None:
@@ -6,8 +6,7 @@ DEFAULT_MODEL_ID = "gpt2"


 def test_initialization_default() -> None:
-    """Test default initialization"""
-
+    """Test default initialization."""
     llm = HuggingFacePipeline()

     assert llm.model_id == DEFAULT_MODEL_ID
@@ -15,8 +14,7 @@ def test_initialization_default() -> None:

 @patch("transformers.pipeline")
 def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
-    """Test initialization with a pipeline object"""
-
+    """Test initialization with a pipeline object."""
     mock_pipe = MagicMock()
     mock_pipe.model.name_or_path = "mock-model-id"
     mock_pipeline.return_value = mock_pipe
@@ -32,8 +30,7 @@ def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
 def test_initialization_with_from_model_id(
     mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
 ) -> None:
-    """Test initialization with the from_model_id method"""
-
+    """Test initialization with the from_model_id method."""
     mock_tokenizer.return_value = MagicMock(pad_token_id=0)
     mock_model.return_value = MagicMock()
