Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-13 00:16:01 +00:00
huggingface[patch]: ruff fixes and rules (#31912)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
This commit is contained in:
parent 06ab2972e3
commit 750721b4c3
@@ -87,7 +87,16 @@ select = [
     "W",    # pycodestyle warning
     "YTT",  # flake8-2020
 ]
-ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D107"]
+ignore = [
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
+]
 pyupgrade.keep-runtime-typing = true

 [tool.mypy]
@@ -180,13 +180,14 @@ select = [
     "YTT",  # flake8-2020
 ]
 ignore = [
-    "D100",  # Missing docstring in public module
-    "D101",  # Missing docstring in public class
-    "D102",  # Missing docstring in public method
-    "D103",  # Missing docstring in public function
-    "D104",  # Missing docstring in public package
-    "D105",  # Missing docstring in magic method
-    "D107",  # Missing docstring in __init__
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
     "COM812",  # Messes with the formatter
     "ISC001",  # Messes with the formatter
     "PERF203",  # Rarely useful
@@ -97,13 +97,14 @@ select = [
     "YTT",  # flake8-2020
 ]
 ignore = [
-    "D100",  # Missing docstring in public module
-    "D101",  # Missing docstring in public class
-    "D102",  # Missing docstring in public method
-    "D103",  # Missing docstring in public function
-    "D104",  # Missing docstring in public package
-    "D105",  # Missing docstring in magic method
-    "D107",  # Missing docstring in __init__
+    "D100",  # pydocstyle: Missing docstring in public module
+    "D101",  # pydocstyle: Missing docstring in public class
+    "D102",  # pydocstyle: Missing docstring in public method
+    "D103",  # pydocstyle: Missing docstring in public function
+    "D104",  # pydocstyle: Missing docstring in public package
+    "D105",  # pydocstyle: Missing docstring in magic method
+    "D107",  # pydocstyle: Missing docstring in __init__
+    "D407",  # pydocstyle: Missing-dashed-underline-after-section
     "D214",  # Section over-indented, doesn't play well with reStructuredText
     "COM812",  # Messes with the formatter
     "ISC001",  # Messes with the formatter
@@ -343,7 +343,6 @@ def function() -> Callable:
     """Dummy function.

     Args:
-    ----
         arg1: foo
         arg2: one of 'bar', 'baz'

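The removal of the `----` underline above, together with ignoring D407 (missing dashed underline after section) in the config hunks, settles on Google-style docstring sections without dashed underlines. A minimal sketch of a docstring in that style, assuming the rule set shown in the pyproject hunks (the function itself is illustrative, not from this commit):

def scale(values: list[float], factor: float = 2.0) -> list[float]:
    """Scale every value by a constant factor.

    Args:
        values: Numbers to scale.
        factor: Multiplier applied to each element.

    Returns:
        A new list with each element multiplied by ``factor``.

    """
    return [v * factor for v in values]

The section header alone is enough; no `----` underline follows `Args:` or `Returns:`.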
@@ -1,5 +1,8 @@
 """Hugging Face Chat Wrapper."""

+from __future__ import annotations
+
+import contextlib
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from dataclasses import dataclass

@@ -58,8 +61,8 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
 from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self

-from ..llms.huggingface_endpoint import HuggingFaceEndpoint
-from ..llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_huggingface.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain_huggingface.llms.huggingface_pipeline import HuggingFacePipeline


 @dataclass

@@ -111,6 +114,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:

     Returns:
         The dictionary.
+
     """
     message_dict: dict[str, Any]
     if isinstance(message, ChatMessage):

@@ -153,7 +157,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
             "tool_call_id": message.tool_call_id,
         }
     else:
-        raise TypeError(f"Got unknown type {message}")
+        msg = f"Got unknown type {message}"
+        raise TypeError(msg)
     if "name" in message.additional_kwargs:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
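The `msg = ...` / `raise TypeError(msg)` rewrite above recurs throughout this commit; it follows the style that ruff's flake8-errmsg (EM) rules push for, where the exception message is bound to a variable rather than passed as a literal or f-string directly to the exception constructor. A small sketch under that assumption (the function and names are illustrative, not from this diff):

def parse_role(role: str) -> str:
    """Return a normalized chat role."""
    allowed = {"user", "assistant", "system"}
    if role not in allowed:
        # EM-style: build the message first, then raise, so the traceback
        # does not repeat a long literal inside the raise statement.
        msg = f"Got unknown role {role!r}; expected one of {sorted(allowed)}"
        raise ValueError(msg)
    return role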
@@ -167,11 +172,12 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:

     Returns:
         The LangChain message.
+
     """
     role = _dict.get("role")
     if role == "user":
         return HumanMessage(content=_dict.get("content", ""))
-    elif role == "assistant":
+    if role == "assistant":
         content = _dict.get("content", "") or ""
         additional_kwargs: dict = {}
         if function_call := _dict.get("function_call"):

@@ -193,13 +199,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_calls=tool_calls,
             invalid_tool_calls=invalid_tool_calls,
         )
-    elif role == "system":
+    if role == "system":
         return SystemMessage(content=_dict.get("content", ""))
-    elif role == "function":
+    if role == "function":
         return FunctionMessage(
             content=_dict.get("content", ""), name=_dict.get("name", "")
         )
-    elif role == "tool":
+    if role == "tool":
         additional_kwargs = {}
         if "name" in _dict:
             additional_kwargs["name"] = _dict["name"]

@@ -208,8 +214,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
             tool_call_id=_dict.get("tool_call_id", ""),
             additional_kwargs=additional_kwargs,
         )
-    else:
-        return ChatMessage(content=_dict.get("content", ""), role=role or "")
+    return ChatMessage(content=_dict.get("content", ""), role=role or "")


 def _is_huggingface_hub(llm: Any) -> bool:

@@ -241,7 +246,7 @@ def _convert_chunk_to_message_chunk(
     if raw_tool_calls := _dict.get("tool_calls"):
         additional_kwargs["tool_calls"] = raw_tool_calls
         for rtc in raw_tool_calls:
-            try:
+            with contextlib.suppress(KeyError):
                 tool_call_chunks.append(
                     create_tool_call_chunk(
                         name=rtc["function"].get("name"),
@@ -250,11 +255,9 @@ def _convert_chunk_to_message_chunk(
                         index=rtc.get("index"),
                     )
                 )
-            except KeyError:
-                pass
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         if usage := chunk.get("usage"):
             input_tokens = usage.get("prompt_tokens", 0)
             output_tokens = usage.get("completion_tokens", 0)
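Swapping `try: ... except KeyError: pass` for `contextlib.suppress(KeyError)` (the new `import contextlib` added at the top of this file supports it) is the rewrite suggested by ruff's flake8-simplify SIM105 rule. A standalone sketch with invented data:

import contextlib

raw_tool_calls = [{"function": {"name": "search"}}, {"malformed": True}]
names: list[str] = []

for rtc in raw_tool_calls:
    # Equivalent to try/except KeyError/pass: a missing key simply skips the append.
    with contextlib.suppress(KeyError):
        names.append(rtc["function"]["name"])

print(names)  # ['search']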
@@ -271,16 +274,15 @@ def _convert_chunk_to_message_chunk(
             tool_call_chunks=tool_call_chunks,
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role == "function" or default_class == FunctionMessageChunk:
+    if role == "function" or default_class == FunctionMessageChunk:
         return FunctionMessageChunk(content=content, name=_dict["name"])
-    elif role == "tool" or default_class == ToolMessageChunk:
+    if role == "tool" or default_class == ToolMessageChunk:
         return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore
+    return default_class(content=content)  # type: ignore[call-arg]


 def _is_huggingface_textgen_inference(llm: Any) -> bool:
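The `elif` → `if` rewrites in this function are the pattern ruff reports as superfluous branching after `return` (the flake8-return RET rules): once a branch returns, the following check does not need `elif` or `else`. A minimal sketch, not the library code:

def describe(n: int) -> str:
    # Every branch returns, so plain `if` statements suffice; no elif/else chain.
    if n < 0:
        return "negative"
    if n == 0:
        return "zero"
    return "positive"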
@@ -304,7 +306,7 @@ def _is_huggingface_pipeline(llm: Any) -> bool:


 class ChatHuggingFace(BaseChatModel):
-    """Hugging Face LLM's as ChatModels.
+    r"""Hugging Face LLM's as ChatModels.

     Works with `HuggingFaceTextGenInference`, `HuggingFaceEndpoint`,
     `HuggingFaceHub`, and `HuggingFacePipeline` LLMs.

@@ -484,7 +486,7 @@ class ChatHuggingFace(BaseChatModel):
     """Whether to return logprobs."""
     top_logprobs: Optional[int] = None
     """Number of most likely tokens to return at each token position, each with
     an associated log probability. `logprobs` must be set to true
     if this parameter is used."""
     logit_bias: Optional[dict[int, int]] = None
     """Modify the likelihood of specified tokens appearing in the completion."""

@@ -511,11 +513,12 @@ class ChatHuggingFace(BaseChatModel):
             and not _is_huggingface_endpoint(self.llm)
             and not _is_huggingface_pipeline(self.llm)
         ):
-            raise TypeError(
+            msg = (
                 "Expected llm to be one of HuggingFaceTextGenInference, "
                 "HuggingFaceEndpoint, HuggingFaceHub, HuggingFacePipeline "
                 f"received {type(self.llm)}"
             )
+            raise TypeError(msg)
         return self

     def _create_chat_result(self, response: dict) -> ChatResult:

@@ -529,7 +532,7 @@ class ChatHuggingFace(BaseChatModel):
                 "output_tokens": token_usage.get("completion_tokens", 0),
                 "total_tokens": token_usage.get("total_tokens", 0),
             }
-            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generation_info = {"finish_reason": res.get("finish_reason")}
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
             gen = ChatGeneration(
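The `dict(finish_reason=...)` → `{"finish_reason": ...}` change above matches the flake8-comprehensions preference (C408) for dict literals over `dict()` calls with keyword arguments; both build the same mapping. For example:

finish_reason = "stop"

generation_info_call = dict(finish_reason=finish_reason)      # flagged form
generation_info_literal = {"finish_reason": finish_reason}    # preferred literal

assert generation_info_call == generation_info_literal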
@@ -558,7 +561,7 @@ class ChatHuggingFace(BaseChatModel):
             message_dicts, params = self._create_message_dicts(messages, stop)
             answer = self.llm.client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             if should_stream:
                 stream_iter = self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs

@@ -573,18 +576,17 @@ class ChatHuggingFace(BaseChatModel):
             }
             answer = self.llm.client.chat_completion(messages=message_dicts, **params)
             return self._create_chat_result(answer)
-        else:
-            llm_input = self._to_chat_prompt(messages)
+        llm_input = self._to_chat_prompt(messages)

-            if should_stream:
-                stream_iter = self.llm._stream(
-                    llm_input, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return generate_from_stream(stream_iter)
-            llm_result = self.llm._generate(
-                prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-            )
-            return self._to_chat_result(llm_result)
+        if should_stream:
+            stream_iter = self.llm._stream(
+                llm_input, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+        llm_result = self.llm._generate(
+            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(llm_result)

     async def _agenerate(
         self,
@@ -598,7 +600,7 @@ class ChatHuggingFace(BaseChatModel):
             message_dicts, params = self._create_message_dicts(messages, stop)
             answer = await self.llm.async_client.chat(messages=message_dicts, **kwargs)
             return self._create_chat_result(answer)
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             should_stream = stream if stream is not None else self.streaming
             if should_stream:
                 stream_iter = self._astream(

@@ -616,17 +618,14 @@ class ChatHuggingFace(BaseChatModel):
                 messages=message_dicts, **params
             )
             return self._create_chat_result(answer)
-
-        elif _is_huggingface_pipeline(self.llm):
-            raise NotImplementedError(
-                "async generation is not supported with HuggingFacePipeline"
-            )
-        else:
-            llm_input = self._to_chat_prompt(messages)
-            llm_result = await self.llm._agenerate(
-                prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
-            )
-            return self._to_chat_result(llm_result)
+        if _is_huggingface_pipeline(self.llm):
+            msg = "async generation is not supported with HuggingFacePipeline"
+            raise NotImplementedError(msg)
+        llm_input = self._to_chat_prompt(messages)
+        llm_result = await self.llm._agenerate(
+            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return self._to_chat_result(llm_result)

     def _stream(
         self,
@@ -721,10 +720,12 @@ class ChatHuggingFace(BaseChatModel):
     ) -> str:
         """Convert a list of messages into a prompt format expected by wrapped LLM."""
         if not messages:
-            raise ValueError("At least one HumanMessage must be provided!")
+            msg = "At least one HumanMessage must be provided!"
+            raise ValueError(msg)

         if not isinstance(messages[-1], HumanMessage):
-            raise ValueError("Last message must be a HumanMessage!")
+            msg = "Last message must be a HumanMessage!"
+            raise ValueError(msg)

         messages_dicts = [self._to_chatml_format(m) for m in messages]


@@ -734,7 +735,6 @@ class ChatHuggingFace(BaseChatModel):

     def _to_chatml_format(self, message: BaseMessage) -> dict:
         """Convert LangChain message to ChatML format."""
-
         if isinstance(message, SystemMessage):
             role = "system"
         elif isinstance(message, AIMessage):

@@ -742,7 +742,8 @@ class ChatHuggingFace(BaseChatModel):
         elif isinstance(message, HumanMessage):
             role = "user"
         else:
-            raise ValueError(f"Unknown message type: {type(message)}")
+            msg = f"Unknown message type: {type(message)}"
+            raise ValueError(msg)

         return {"role": role, "content": message.content}

@@ -761,8 +762,7 @@ class ChatHuggingFace(BaseChatModel):
         )

     def _resolve_model_id(self) -> None:
-        """Resolve the model_id from the LLM's inference_server_url"""
-
+        """Resolve the model_id from the LLM's inference_server_url."""
         from huggingface_hub import list_inference_endpoints  # type: ignore[import]

         if _is_huggingface_hub(self.llm) or (

@@ -770,9 +770,9 @@ class ChatHuggingFace(BaseChatModel):
         ):
             self.model_id = self.llm.repo_id
             return
-        elif _is_huggingface_textgen_inference(self.llm):
+        if _is_huggingface_textgen_inference(self.llm):
             endpoint_url: Optional[str] = self.llm.inference_server_url
-        elif _is_huggingface_pipeline(self.llm):
+        if _is_huggingface_pipeline(self.llm):
             from transformers import AutoTokenizer  # type: ignore[import]

             self.model_id = self.model_id or self.llm.model_id

@@ -782,29 +782,29 @@ class ChatHuggingFace(BaseChatModel):
                 else self.tokenizer
             )
             return
-        elif _is_huggingface_endpoint(self.llm):
+        if _is_huggingface_endpoint(self.llm):
             self.model_id = self.llm.repo_id or self.llm.model
             return
-        else:
-            endpoint_url = self.llm.endpoint_url
+        endpoint_url = self.llm.endpoint_url
         available_endpoints = list_inference_endpoints("*")
         for endpoint in available_endpoints:
             if endpoint.url == endpoint_url:
                 self.model_id = endpoint.repository

         if not self.model_id:
-            raise ValueError(
+            msg = (
                 "Failed to resolve model_id:"
                 f"Could not find model id for inference server: {endpoint_url}"
                 "Make sure that your Hugging Face token has access to the endpoint."
             )
+            raise ValueError(msg)

     def bind_tools(
         self,
         tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
         *,
         tool_choice: Optional[
-            Union[dict, str, Literal["auto", "none", "required"], bool]
+            Union[dict, str, Literal["auto", "none", "required"], bool]  # noqa: PYI051
         ] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -823,15 +823,16 @@ class ChatHuggingFace(BaseChatModel):
                 {"type": "function", "function": {"name": <<tool_name>>}}.
             **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
-        """

+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None and tool_choice:
             if len(formatted_tools) != 1:
-                raise ValueError(
+                msg = (
                     "When specifying `tool_choice`, you must provide exactly one "
                     f"tool. Received {len(formatted_tools)} tools."
                 )
+                raise ValueError(msg)
             if isinstance(tool_choice, str):
                 if tool_choice not in ("auto", "none", "required"):
                     tool_choice = {

@@ -845,15 +846,17 @@ class ChatHuggingFace(BaseChatModel):
                     formatted_tools[0]["function"]["name"]
                     != tool_choice["function"]["name"]
                 ):
-                    raise ValueError(
+                    msg = (
                         f"Tool choice {tool_choice} was specified, but the only "
                         f"provided tool was {formatted_tools[0]['function']['name']}."
                     )
+                    raise ValueError(msg)
             else:
-                raise ValueError(
+                msg = (
                     f"Unrecognized tool_choice type. Expected str, bool or dict. "
                     f"Received: {tool_choice}"
                 )
+                raise ValueError(msg)
         kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)

@@ -892,6 +895,11 @@ class ChatHuggingFace(BaseChatModel):
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".

+            kwargs:
+                Additional parameters to pass to the underlying LLM's
+                :meth:`langchain_core.language_models.chat.BaseChatModel.bind`
+                method, such as `response_format` or `ls_structured_output_format`.
+
         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.


@@ -908,14 +916,16 @@ class ChatHuggingFace(BaseChatModel):
        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
        is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
        if method == "function_calling":
            if schema is None:
-                raise ValueError(
+                msg = (
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
+                raise ValueError(msg)
            formatted_tool = convert_to_openai_tool(schema)
            tool_name = formatted_tool["function"]["name"]
            llm = self.bind_tools(
@@ -927,19 +937,18 @@ class ChatHuggingFace(BaseChatModel):
                },
            )
            if is_pydantic_schema:
-                raise NotImplementedError(
-                    "Pydantic schema is not supported for function calling"
-                )
-            else:
-                output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
-                    JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
-                )
+                msg = "Pydantic schema is not supported for function calling"
+                raise NotImplementedError(msg)
+            output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
+                JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
+            )
        elif method == "json_schema":
            if schema is None:
-                raise ValueError(
+                msg = (
                    "schema must be specified when method is 'json_schema'. "
                    "Received None."
                )
+                raise ValueError(msg)
            formatted_schema = convert_to_json_schema(schema)
            llm = self.bind(
                response_format={"type": "json_object", "schema": formatted_schema},

@@ -963,10 +972,11 @@ class ChatHuggingFace(BaseChatModel):
                JsonOutputKeyToolsParser, JsonOutputParser
            ] = JsonOutputParser()  # type: ignore[arg-type]
        else:
-            raise ValueError(
+            msg = (
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )
+            raise ValueError(msg)

        if include_raw:
            parser_assign = RunnablePassthrough.assign(

@@ -977,8 +987,7 @@ class ChatHuggingFace(BaseChatModel):
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser

    def _create_message_dicts(
        self, messages: list[BaseMessage], stop: Optional[list[str]]

@@ -991,8 +1000,7 @@ class ChatHuggingFace(BaseChatModel):

    @property
    def _default_params(self) -> dict[str, Any]:
-        """Get the default parameters for calling Hugging Face
-        Inference Providers API."""
+        """Get default parameters for calling Hugging Face Inference Providers API."""
        params = {
            "model": self.model_id,
            "stream": self.streaming,
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 from typing import Any, Optional

 from langchain_core.embeddings import Embeddings
 from pydantic import BaseModel, ConfigDict, Field

-from ..utils.import_utils import (
+from langchain_huggingface.utils.import_utils import (
     IMPORT_ERROR,
     is_ipex_available,
     is_optimum_intel_available,

@@ -33,12 +35,13 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
                 model_kwargs=model_kwargs,
                 encode_kwargs=encode_kwargs
             )

     """
+
     model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
     """Model name to use."""
     cache_folder: Optional[str] = None
     """Path to store models.
     Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass to the Sentence Transformer model, such as `device`,

@@ -46,12 +49,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
     encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the documents of
     the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
     `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
     query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the query of
     the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
     `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
     multi_process: bool = False
@@ -65,24 +68,25 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         try:
             import sentence_transformers  # type: ignore[import]
         except ImportError as exc:
-            raise ImportError(
+            msg = (
                 "Could not import sentence_transformers python package. "
                 "Please install it with `pip install sentence-transformers`."
-            ) from exc
+            )
+            raise ImportError(msg) from exc

         if self.model_kwargs.get("backend", "torch") == "ipex":
             if not is_optimum_intel_available() or not is_ipex_available():
-                raise ImportError(
-                    f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
-                )
+                msg = f'Backend: ipex {IMPORT_ERROR.format("optimum[ipex]")}'
+                raise ImportError(msg)

             if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
-                raise ImportError(
+                msg = (
                     f"Backend: ipex requires optimum-intel>="
                     f"{_MIN_OPTIMUM_VERSION}. You can install it with pip: "
                     "`pip install --upgrade --upgrade-strategy eager "
                     "`optimum[ipex]`."
                 )
+                raise ImportError(msg)

             from optimum.intel import IPEXSentenceTransformer  # type: ignore[import]


@@ -104,21 +108,21 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     def _embed(
         self, texts: list[str], encode_kwargs: dict[str, Any]
     ) -> list[list[float]]:
-        """
-        Embed a text using the HuggingFace transformer model.
+        """Embed a text using the HuggingFace transformer model.

         Args:
             texts: The list of texts to embed.
             encode_kwargs: Keyword arguments to pass when calling the
                 `encode` method for the documents of the SentenceTransformer
                 encode method.

         Returns:
             List of embeddings, one for each text.
+
         """
         import sentence_transformers  # type: ignore[import]

-        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        texts = [x.replace("\n", " ") for x in texts]
         if self.multi_process:
             pool = self._client.start_multi_process_pool()
             embeddings = self._client.encode_multi_process(texts, pool)
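Rewriting `list(map(lambda x: x.replace("\n", " "), texts))` as a list comprehension, as the hunk above does, is the form ruff's C417 rule (unnecessary map) prefers; the two expressions produce the same list. A quick check with sample strings:

texts = ["first\nline", "second"]

mapped = list(map(lambda x: x.replace("\n", " "), texts))
comprehension = [x.replace("\n", " ") for x in texts]

assert mapped == comprehension == ["first line", "second"]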
@@ -127,16 +131,17 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
             embeddings = self._client.encode(
                 texts,
                 show_progress_bar=self.show_progress,
-                **encode_kwargs,  # type: ignore
+                **encode_kwargs,
             )

         if isinstance(embeddings, list):
-            raise TypeError(
+            msg = (
                 "Expected embeddings to be a Tensor or a numpy array, "
                 "got a list instead."
             )
+            raise TypeError(msg)

-        return embeddings.tolist()
+        return embeddings.tolist()  # type: ignore[return-type]

     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.

@@ -146,6 +151,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):

         Returns:
             List of embeddings, one for each text.
+
         """
         return self._embed(texts, self.encode_kwargs)


@@ -157,6 +163,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):

         Returns:
             Embeddings for the text.
+
         """
         embed_kwargs = (
             self.query_encode_kwargs
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from typing import Any, Optional


@@ -27,6 +29,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
                 task="feature-extraction",
                 huggingfacehub_api_token="my-api-key",
             )

     """
+
     client: Any = None  #: :meta private:

@@ -35,7 +38,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
     """Model name to use."""
     provider: Optional[str] = None
     """Name of the provider to use for inference with the model specified in
     ``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
     available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
     repo_id: Optional[str] = None
     """Huggingfacehub repository id, for backward compatibility."""

@@ -87,18 +90,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
            )

            if self.task not in VALID_TASKS:
-                raise ValueError(
+                msg = (
                    f"Got invalid task {self.task}, "
                    f"currently only {VALID_TASKS} are supported"
                )
+                raise ValueError(msg)
            self.client = client
            self.async_client = async_client

        except ImportError:
-            raise ImportError(
+            msg = (
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
+            raise ImportError(msg)
        return self

    def embed_documents(self, texts: list[str]) -> list[list[float]]:

@@ -109,6 +114,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):

        Returns:
            List of embeddings, one for each text.
+
        """
        # replace newlines, which can negatively affect performance.
        texts = [text.replace("\n", " ") for text in texts]

@@ -125,6 +131,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):

        Returns:
            List of embeddings, one for each text.
+
        """
        # replace newlines, which can negatively affect performance.
        texts = [text.replace("\n", " ") for text in texts]

@@ -142,9 +149,9 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):

        Returns:
            Embeddings for the text.
+
        """
-        response = self.embed_documents([text])[0]
-        return response
+        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> list[float]:
        """Async Call to HuggingFaceHub's embedding endpoint for embedding query text.

@@ -154,6 +161,6 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):

        Returns:
            Embeddings for the text.
+
        """
-        response = (await self.aembed_documents([text]))[0]
-        return response
+        return (await self.aembed_documents([text]))[0]
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import inspect
 import logging
 import os

@@ -25,8 +27,7 @@ VALID_TASKS = (


 class HuggingFaceEndpoint(LLM):
-    """
-    Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.
+    """Hugging Face Endpoint. This works with any model that supports text generation (i.e. text completion) task.

     To use this class, you should have installed the ``huggingface_hub`` package, and
     the environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token,

@@ -75,6 +76,7 @@ class HuggingFaceEndpoint(LLM):
            huggingfacehub_api_token="my-api-key"
        )
        print(llm.invoke("What is Deep Learning?"))

    """  # noqa: E501
+
    endpoint_url: Optional[str] = None

@@ -147,7 +149,8 @@ class HuggingFaceEndpoint(LLM):
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
-                raise ValueError(f"Found {field_name} supplied twice.")
+                msg = f"Found {field_name} supplied twice."
+                raise ValueError(msg)
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.

@@ -158,10 +161,11 @@ class HuggingFaceEndpoint(LLM):

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
-            raise ValueError(
+            msg = (
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
+            raise ValueError(msg)

        values["model_kwargs"] = extra

@@ -185,18 +189,20 @@ class HuggingFaceEndpoint(LLM):
        repo_id = values.get("repo_id")

        if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
-            raise ValueError(
+            msg = (
                "Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
                "not more than one."
            )
+            raise ValueError(msg)
        values["model"] = (
            model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
        )
        if not values["model"]:
-            raise ValueError(
+            msg = (
                "Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
                "model."
            )
+            raise ValueError(msg)
        return values

    @model_validator(mode="after")

@@ -274,12 +280,10 @@ class HuggingFaceEndpoint(LLM):
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
-            **{
-                "endpoint_url": self.endpoint_url,
-                "task": self.task,
-                "provider": self.provider,
-            },
-            **{"model_kwargs": _model_kwargs},
+            "endpoint_url": self.endpoint_url,
+            "task": self.task,
+            "provider": self.provider,
+            "model_kwargs": _model_kwargs,
        }

    @property

@@ -310,19 +314,19 @@ class HuggingFaceEndpoint(LLM):
            ):
                completion += chunk.text
            return completion
-        else:
-            response_text = self.client.text_generation(
-                prompt=prompt,
-                model=self.model,
-                **invocation_params,
-            )
-
-            # Maybe the generation has stopped at one of the stop sequences:
-            # then we remove this stop sequence from the end of the generated text
-            for stop_seq in invocation_params["stop"]:
-                if response_text[-len(stop_seq) :] == stop_seq:
-                    response_text = response_text[: -len(stop_seq)]
-            return response_text
+
+        response_text = self.client.text_generation(
+            prompt=prompt,
+            model=self.model,
+            **invocation_params,
+        )
+
+        # Maybe the generation has stopped at one of the stop sequences:
+        # then we remove this stop sequence from the end of the generated text
+        for stop_seq in invocation_params["stop"]:
+            if response_text[-len(stop_seq) :] == stop_seq:
+                response_text = response_text[: -len(stop_seq)]
+        return response_text

    async def _acall(
        self,

@@ -339,20 +343,20 @@ class HuggingFaceEndpoint(LLM):
            ):
                completion += chunk.text
            return completion
-        else:
-            response_text = await self.async_client.text_generation(
-                prompt=prompt,
-                **invocation_params,
-                model=self.model,
-                stream=False,
-            )
-
-            # Maybe the generation has stopped at one of the stop sequences:
-            # then remove this stop sequence from the end of the generated text
-            for stop_seq in invocation_params["stop"]:
-                if response_text[-len(stop_seq) :] == stop_seq:
-                    response_text = response_text[: -len(stop_seq)]
-            return response_text
+
+        response_text = await self.async_client.text_generation(
+            prompt=prompt,
+            **invocation_params,
+            model=self.model,
+            stream=False,
+        )
+
+        # Maybe the generation has stopped at one of the stop sequences:
+        # then remove this stop sequence from the end of the generated text
+        for stop_seq in invocation_params["stop"]:
+            if response_text[-len(stop_seq) :] == stop_seq:
+                response_text = response_text[: -len(stop_seq)]
+        return response_text

    def _stream(
        self,
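The stop-sequence handling in the two hunks above trims a trailing stop string from the generated text before returning it. A self-contained sketch of that trimming step (the sample strings are invented):

def trim_stop_sequences(response_text: str, stop: list[str]) -> str:
    """Remove a stop sequence from the end of the generated text, if present."""
    for stop_seq in stop:
        if stop_seq and response_text[-len(stop_seq):] == stop_seq:
            response_text = response_text[: -len(stop_seq)]
    return response_text

print(trim_stop_sequences("Deep Learning is...\nHuman:", ["\nHuman:"]))  # 'Deep Learning is...'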
@@ -10,7 +10,7 @@ from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from pydantic import ConfigDict, model_validator

-from ..utils.import_utils import (
+from langchain_huggingface.utils.import_utils import (
     IMPORT_ERROR,
     is_ipex_available,
     is_openvino_available,
@@ -87,7 +87,7 @@ class HuggingFacePipeline(BaseLLM):
    def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Ensure model_id is set either by pipeline or user input."""
        if "model_id" not in values:
-            if "pipeline" in values and values["pipeline"]:
+            if values.get("pipeline"):
                values["model_id"] = values["pipeline"].model.name_or_path
            else:
                values["model_id"] = DEFAULT_MODEL_ID

@@ -116,34 +116,38 @@ class HuggingFacePipeline(BaseLLM):
            from transformers import pipeline as hf_pipeline  # type: ignore[import]

        except ImportError:
-            raise ValueError(
+            msg = (
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
+            raise ValueError(msg)

        _model_kwargs = model_kwargs.copy() if model_kwargs else {}
        if device_map is not None:
            if device is not None:
-                raise ValueError(
+                msg = (
                    "Both `device` and `device_map` are specified. "
                    "`device` will override `device_map`. "
                    "You will most likely encounter unexpected behavior."
                    "Please remove `device` and keep "
                    "`device_map`."
                )
+                raise ValueError(msg)

            if "device_map" in _model_kwargs:
-                raise ValueError("`device_map` is already specified in `model_kwargs`.")
+                msg = "`device_map` is already specified in `model_kwargs`."
+                raise ValueError(msg)

            _model_kwargs["device_map"] = device_map
        tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

        if backend in {"openvino", "ipex"}:
            if task not in VALID_TASKS:
-                raise ValueError(
+                msg = (
                    f"Got invalid task {task}, "
                    f"currently only {VALID_TASKS} are supported"
                )
+                raise ValueError(msg)

            err_msg = f'Backend: {backend} {IMPORT_ERROR.format(f"optimum[{backend}]")}'
            if not is_optimum_intel_available():
@@ -156,12 +160,13 @@ class HuggingFacePipeline(BaseLLM):
                else _MIN_OPTIMUM_VERSION
            )
            if is_optimum_intel_version("<", min_optimum_version):
-                raise ImportError(
+                msg = (
                    f"Backend: {backend} requires optimum-intel>="
                    f"{min_optimum_version}. You can install it with pip: "
                    "`pip install --upgrade --upgrade-strategy eager "
                    f"`optimum[{backend}]`."
                )
+                raise ImportError(msg)

            if backend == "openvino":
                if not is_openvino_available():

@@ -240,10 +245,11 @@ class HuggingFacePipeline(BaseLLM):

            cuda_device_count = torch.cuda.device_count()
            if device < -1 or (device >= cuda_device_count):
-                raise ValueError(
+                msg = (
                    f"Got device=={device}, "
                    f"device is required to be within [-1, {cuda_device_count})"
                )
+                raise ValueError(msg)
            if device_map is not None and device < 0:
                device = None
            if device is not None and device < 0 and cuda_device_count > 0:

@@ -271,10 +277,11 @@ class HuggingFacePipeline(BaseLLM):
            **_pipeline_kwargs,
        )
        if pipeline.task not in VALID_TASKS:
-            raise ValueError(
+            msg = (
                f"Got invalid task {pipeline.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
+            raise ValueError(msg)
        return cls(
            pipeline=pipeline,
            model_id=model_id,
@@ -324,21 +331,22 @@ class HuggingFacePipeline(BaseLLM):
                    # if model returns multiple generations, pick the top one
                    response = response[0]

-                if self.pipeline.task == "text-generation":
-                    text = response["generated_text"]
-                elif self.pipeline.task == "text2text-generation":
-                    text = response["generated_text"]
-                elif self.pipeline.task == "image-text-to-text":
-                    text = response["generated_text"]
+                if (
+                    self.pipeline.task == "text-generation"
+                    or self.pipeline.task == "text2text-generation"
+                    or self.pipeline.task == "image-text-to-text"
+                ):
+                    text = response["generated_text"]
                elif self.pipeline.task == "summarization":
                    text = response["summary_text"]
                elif self.pipeline.task in "translation":
                    text = response["translation_text"]
                else:
-                    raise ValueError(
+                    msg = (
                        f"Got invalid task {self.pipeline.task}, "
                        f"currently only {VALID_TASKS} are supported"
                    )
+                    raise ValueError(msg)
                if skip_prompt:
                    text = text[len(batch_prompts[j]) :]
                # Append the processed text to results
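Three branches that shared one body are merged into a single condition here; the patch chains `or`, while the hedged toy version below uses a set membership test to express the same collapse:

    def pick_text(task: str, response: dict[str, str]) -> str:
        """Toy version of the merged branch: one condition, one body."""
        if task in {"text-generation", "text2text-generation", "image-text-to-text"}:
            return response["generated_text"]
        if task == "summarization":
            return response["summary_text"]
        msg = f"Got invalid task {task}"
        raise ValueError(msg)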
@@ -378,10 +386,7 @@ class HuggingFacePipeline(BaseLLM):
                scores: torch.FloatTensor,
                **kwargs: Any,
            ) -> bool:
-                for stop_id in stopping_ids_list:
-                    if input_ids[0][-1] == stop_id:
-                        return True
-                return False
+                return any(input_ids[0][-1] == stop_id for stop_id in stopping_ids_list)

        stopping_criteria = StoppingCriteriaList([StopOnTokens()])
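Rewriting a scan-and-return loop as `any()` is ruff's flake8-simplify (SIM110) fix; a self-contained sketch of the same transformation with made-up stop ids:

    STOP_IDS = [2, 50256]

    def hit_stop_token_loop(last_token: int) -> bool:
        """Original shape: loop, early return, trailing False."""
        for stop_id in STOP_IDS:
            if last_token == stop_id:
                return True
        return False

    def hit_stop_token_any(last_token: int) -> bool:
        """Equivalent one-liner with a generator expression."""
        return any(last_token == stop_id for stop_id in STOP_IDS)

    assert hit_stop_token_loop(2) and hit_stop_token_any(2)
    assert not hit_stop_token_loop(7) and not hit_stop_token_any(7)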
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import importlib.metadata
import importlib.util
import operator as op

@@ -46,8 +48,7 @@ def compare_versions(
    operation: str,
    requirement_version: str,
) -> bool:
-    """
-    Compare a library version to some requirement using a given operation.
+    """Compare a library version to some requirement using a given operation.

    Arguments:
        library_or_version (`str` or `packaging.version.Version`):

@@ -56,12 +57,14 @@ def compare_versions(
            A string representation of an operator, such as `">"` or `"<="`.
        requirement_version (`str`):
            The version to compare the library version against
+
    """
-    if operation not in STR_OPERATION_TO_FUNC.keys():
-        raise ValueError(
+    if operation not in STR_OPERATION_TO_FUNC:
+        msg = (
            f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}"
            f", received {operation}"
        )
+        raise ValueError(msg)
    if isinstance(library_or_version, str):
        library_or_version = version.parse(
            importlib.metadata.version(library_or_version)
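Two more rule families appear in this file: pydocstyle's preference for putting the summary on the first docstring line, and flake8-simplify's SIM118, which drops a redundant `.keys()` from membership tests. A short hedged sketch combining both, with invented names:

    OPERATORS = {"<": "lt", ">": "gt"}

    def describe(operation: str) -> str:
        """Return the short name of a comparison operator.

        The summary sits on the first line after the opening quotes, matching
        the docstring edits above.
        """
        if operation not in OPERATORS:  # SIM118: no need for OPERATORS.keys()
            msg = f"`operation` must be one of {list(OPERATORS)}, received {operation}"
            raise ValueError(msg)
        return OPERATORS[operation]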
@@ -88,9 +91,7 @@ def is_openvino_available() -> bool:


def is_optimum_version(operation: str, reference_version: str) -> bool:
-    """
-    Compare the current Optimum version to a given reference with an operation.
-    """
+    """Compare the current Optimum version to a given reference with an operation."""
    if not _optimum_version:
        return False
    return compare_versions(

@@ -99,9 +100,7 @@ def is_optimum_version(operation: str, reference_version: str) -> bool:


def is_optimum_intel_version(operation: str, reference_version: str) -> bool:
-    """
-    Compare the current Optimum Intel version to a given reference with an operation.
-    """
+    """Compare current Optimum Intel version to a given reference with an operation."""
    if not _optimum_intel_version:
        return False
    return compare_versions(
@@ -58,8 +58,60 @@ disallow_untyped_defs = "True"
target-version = "py39"

[tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = ["UP007"]
+select = [
+    "A", # flake8-builtins
+    "ASYNC", # flake8-async
+    "C4", # flake8-comprehensions
+    "COM", # flake8-commas
+    "D", # pydocstyle
+    "DOC", # pydoclint
+    "E", # pycodestyle error
+    "EM", # flake8-errmsg
+    "F", # pyflakes
+    "FA", # flake8-future-annotations
+    "FBT", # flake8-boolean-trap
+    "FLY", # flake8-flynt
+    "I", # isort
+    "ICN", # flake8-import-conventions
+    "INT", # flake8-gettext
+    "ISC", # isort-comprehensions
+    "PGH", # pygrep-hooks
+    "PIE", # flake8-pie
+    "PERF", # flake8-perf
+    "PYI", # flake8-pyi
+    "Q", # flake8-quotes
+    "RET", # flake8-return
+    "RSE", # flake8-rst-docstrings
+    "RUF", # ruff
+    "S", # flake8-bandit
+    "SLF", # flake8-self
+    "SLOT", # flake8-slots
+    "SIM", # flake8-simplify
+    "T10", # flake8-debugger
+    "T20", # flake8-print
+    "TID", # flake8-tidy-imports
+    "UP", # pyupgrade
+    "W", # pycodestyle warning
+    "YTT", # flake8-2020
+]
+ignore = [
+    "D100", # pydocstyle: Missing docstring in public module
+    "D101", # pydocstyle: Missing docstring in public class
+    "D102", # pydocstyle: Missing docstring in public method
+    "D103", # pydocstyle: Missing docstring in public function
+    "D104", # pydocstyle: Missing docstring in public package
+    "D105", # pydocstyle: Missing docstring in magic method
+    "D107", # pydocstyle: Missing docstring in __init__
+    "D407", # pydocstyle: Missing-dashed-underline-after-section
+    "COM812", # Messes with the formatter
+    "ISC001", # Messes with the formatter
+    "PERF203", # Rarely useful
+    "S112", # Rarely useful
+    "RUF012", # Doesn't play well with Pydantic
+    "SLF001", # Private member access
+    "UP007", # pyupgrade: non-pep604-annotation-union
+
+]

[tool.coverage.run]
omit = ["tests/*"]
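Most of the newly ignored codes silence missing-docstring checks, but D407 concerns section formatting: it expects a numpy-style dashed underline after headers such as `Args:`. Ignoring it presumably keeps Google-style sections, as in the invented stub below, lint-clean:

    def resize(image_path: str, scale: float) -> str:
        """Resize an image on disk and return the output path.

        Args:
            image_path: Location of the source image.
            scale: Multiplier applied to both dimensions.

        Returns:
            Path of the resized copy.
        """
        # Stub body; only the docstring layout matters for the lint rules.
        return image_path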
@@ -4,4 +4,3 @@ import pytest  # type: ignore[import-not-found, import-not-found]
@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
-    pass
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

import pytest
from langchain_core.language_models import BaseChatModel
@@ -31,8 +31,7 @@ def mock_llm() -> Mock:
    "langchain_huggingface.chat_models.huggingface.ChatHuggingFace._resolve_model_id"
)
def chat_hugging_face(mock_resolve_id: Any, mock_llm: Any) -> ChatHuggingFace:
-    chat_hf = ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())
-    return chat_hf
+    return ChatHuggingFace(llm=mock_llm, tokenizer=MagicMock())


def test_create_chat_result(chat_hugging_face: Any) -> None:
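The fixture change is ruff's flake8-return (RET504) fix: a value that is assigned and immediately returned can be returned directly. A toy sketch of the before-and-after shapes:

    def greet_verbose(name: str) -> str:
        greeting = f"Hello, {name}!"  # RET504: unnecessary assignment before return
        return greeting

    def greet(name: str) -> str:
        return f"Hello, {name}!"

    assert greet_verbose("mock-model-id") == greet("mock-model-id")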
@@ -6,8 +6,7 @@ DEFAULT_MODEL_ID = "gpt2"


def test_initialization_default() -> None:
-    """Test default initialization"""
-
+    """Test default initialization."""
    llm = HuggingFacePipeline()

    assert llm.model_id == DEFAULT_MODEL_ID

@@ -15,8 +14,7 @@ def test_initialization_default() -> None:

@patch("transformers.pipeline")
def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
-    """Test initialization with a pipeline object"""
-
+    """Test initialization with a pipeline object."""
    mock_pipe = MagicMock()
    mock_pipe.model.name_or_path = "mock-model-id"
    mock_pipeline.return_value = mock_pipe

@@ -32,8 +30,7 @@ def test_initialization_with_pipeline(mock_pipeline: MagicMock) -> None:
def test_initialization_with_from_model_id(
    mock_pipeline: MagicMock, mock_model: MagicMock, mock_tokenizer: MagicMock
) -> None:
-    """Test initialization with the from_model_id method"""
-
+    """Test initialization with the from_model_id method."""
    mock_tokenizer.return_value = MagicMock(pad_token_id=0)
    mock_model.return_value = MagicMock()