mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
@@ -7,7 +7,7 @@ import json
|
||||
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -316,7 +316,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
the HuggingFace Hub.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-huggingface`` and ensure your Hugging Face token
|
||||
Install `langchain-huggingface` and ensure your Hugging Face token
|
||||
is saved.
|
||||
|
||||
.. code-block:: bash
|
||||
@@ -478,33 +478,33 @@ class ChatHuggingFace(BaseChatModel):
|
||||
HuggingFaceHub, or HuggingFacePipeline."""
|
||||
tokenizer: Any = None
|
||||
"""Tokenizer for the model. Only used for HuggingFacePipeline."""
|
||||
model_id: Optional[str] = None
|
||||
model_id: str | None = None
|
||||
"""Model ID for the model. Only used for HuggingFaceEndpoint."""
|
||||
temperature: Optional[float] = None
|
||||
temperature: float | None = None
|
||||
"""What sampling temperature to use."""
|
||||
stop: Optional[Union[str, list[str]]] = Field(default=None, alias="stop_sequences")
|
||||
stop: str | list[str] | None = Field(default=None, alias="stop_sequences")
|
||||
"""Default stop sequences."""
|
||||
presence_penalty: Optional[float] = None
|
||||
presence_penalty: float | None = None
|
||||
"""Penalizes repeated tokens."""
|
||||
frequency_penalty: Optional[float] = None
|
||||
frequency_penalty: float | None = None
|
||||
"""Penalizes repeated tokens according to frequency."""
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None
|
||||
"""Seed for generation"""
|
||||
logprobs: Optional[bool] = None
|
||||
logprobs: bool | None = None
|
||||
"""Whether to return logprobs."""
|
||||
top_logprobs: Optional[int] = None
|
||||
top_logprobs: int | None = None
|
||||
"""Number of most likely tokens to return at each token position, each with
|
||||
an associated log probability. `logprobs` must be set to true
|
||||
if this parameter is used."""
|
||||
logit_bias: Optional[dict[int, int]] = None
|
||||
logit_bias: dict[int, int] | None = None
|
||||
"""Modify the likelihood of specified tokens appearing in the completion."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
n: Optional[int] = None
|
||||
n: int | None = None
|
||||
"""Number of chat completions to generate for each prompt."""
|
||||
top_p: Optional[float] = None
|
||||
top_p: float | None = None
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
max_tokens: Optional[int] = None
|
||||
max_tokens: int | None = None
|
||||
"""Maximum number of tokens to generate."""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
@@ -558,9 +558,9 @@ class ChatHuggingFace(BaseChatModel):
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None, # noqa: FBT001
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
@@ -599,9 +599,9 @@ class ChatHuggingFace(BaseChatModel):
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None, # noqa: FBT001
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
@@ -638,8 +638,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
if _is_huggingface_endpoint(self.llm):
|
||||
@@ -687,8 +687,8 @@ class ChatHuggingFace(BaseChatModel):
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
@@ -779,7 +779,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
self.model_id = self.llm.repo_id
|
||||
return
|
||||
if _is_huggingface_textgen_inference(self.llm):
|
||||
endpoint_url: Optional[str] = self.llm.inference_server_url
|
||||
endpoint_url: str | None = self.llm.inference_server_url
|
||||
if _is_huggingface_pipeline(self.llm):
|
||||
from transformers import AutoTokenizer # type: ignore[import]
|
||||
|
||||
@@ -809,11 +809,9 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
tools: Sequence[dict[str, Any] | type | Callable | BaseTool],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required"], bool] # noqa: PYI051
|
||||
] = None,
|
||||
tool_choice: dict | str | bool | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
@@ -826,7 +824,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
`langchain_core.utils.function_calling.convert_to_openai_tool`.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
``'auto'`` to automatically determine which function to call
|
||||
`'auto'` to automatically determine which function to call
|
||||
(if any), or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
@@ -870,14 +868,14 @@ class ChatHuggingFace(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
schema: dict | type[BaseModel] | None = None,
|
||||
*,
|
||||
method: Literal[
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
@@ -948,7 +946,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
if is_pydantic_schema:
|
||||
msg = "Pydantic schema is not supported for function calling"
|
||||
raise NotImplementedError(msg)
|
||||
output_parser: Union[JsonOutputKeyToolsParser, JsonOutputParser] = (
|
||||
output_parser: JsonOutputKeyToolsParser | JsonOutputParser = (
|
||||
JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
|
||||
)
|
||||
elif method == "json_schema":
|
||||
@@ -966,9 +964,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
"schema": schema,
|
||||
},
|
||||
)
|
||||
output_parser: Union[ # type: ignore[no-redef]
|
||||
JsonOutputKeyToolsParser, JsonOutputParser
|
||||
] = JsonOutputParser() # type: ignore[arg-type]
|
||||
output_parser = JsonOutputParser() # type: ignore[arg-type]
|
||||
elif method == "json_mode":
|
||||
llm = self.bind(
|
||||
response_format={"type": "json_object"},
|
||||
@@ -977,9 +973,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
"schema": schema,
|
||||
},
|
||||
)
|
||||
output_parser: Union[ # type: ignore[no-redef]
|
||||
JsonOutputKeyToolsParser, JsonOutputParser
|
||||
] = JsonOutputParser() # type: ignore[arg-type]
|
||||
output_parser = JsonOutputParser() # type: ignore[arg-type]
|
||||
else:
|
||||
msg = (
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
@@ -999,7 +993,7 @@ class ChatHuggingFace(BaseChatModel):
|
||||
return llm | output_parser
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
self, messages: list[BaseMessage], stop: list[str] | None
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
params = self._default_params
|
||||
if stop is not None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -40,7 +40,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
|
||||
"""Model name to use."""
|
||||
cache_folder: Optional[str] = None
|
||||
cache_folder: str | None = None
|
||||
"""Path to store models.
|
||||
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import from_env
|
||||
@@ -35,20 +35,20 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
async_client: Any = None #: :meta private:
|
||||
model: Optional[str] = None
|
||||
model: str | None = None
|
||||
"""Model name to use."""
|
||||
provider: Optional[str] = None
|
||||
provider: str | None = None
|
||||
"""Name of the provider to use for inference with the model specified in
|
||||
``repo_id``. e.g. "sambanova". if not specified, defaults to HF Inference API.
|
||||
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
|
||||
repo_id: Optional[str] = None
|
||||
repo_id: str | None = None
|
||||
"""Huggingfacehub repository id, for backward compatibility."""
|
||||
task: Optional[str] = "feature-extraction"
|
||||
task: str | None = "feature-extraction"
|
||||
"""Task to call the model with."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
model_kwargs: dict | None = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
huggingfacehub_api_token: Optional[str] = Field(
|
||||
huggingfacehub_api_token: str | None = Field(
|
||||
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import inspect
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterator, Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -79,42 +79,42 @@ class HuggingFaceEndpoint(LLM):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
endpoint_url: Optional[str] = None
|
||||
endpoint_url: str | None = None
|
||||
"""Endpoint URL to use. If repo_id is not specified then this needs to given or
|
||||
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
|
||||
repo_id: Optional[str] = None
|
||||
repo_id: str | None = None
|
||||
"""Repo to use. If endpoint_url is not specified then this needs to given"""
|
||||
provider: Optional[str] = None
|
||||
provider: str | None = None
|
||||
"""Name of the provider to use for inference with the model specified in `repo_id`.
|
||||
e.g. "cerebras". if not specified, Defaults to "auto" i.e. the first of the
|
||||
providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
|
||||
available providers can be found in the [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/guides/inference#supported-providers-and-tasks)."""
|
||||
huggingfacehub_api_token: Optional[str] = Field(
|
||||
huggingfacehub_api_token: str | None = Field(
|
||||
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
|
||||
)
|
||||
max_new_tokens: int = 512
|
||||
"""Maximum number of generated tokens"""
|
||||
top_k: Optional[int] = None
|
||||
top_k: int | None = None
|
||||
"""The number of highest probability vocabulary tokens to keep for
|
||||
top-k-filtering."""
|
||||
top_p: Optional[float] = 0.95
|
||||
top_p: float | None = 0.95
|
||||
"""If set to < 1, only the smallest set of most probable tokens with probabilities
|
||||
that add up to `top_p` or higher are kept for generation."""
|
||||
typical_p: Optional[float] = 0.95
|
||||
typical_p: float | None = 0.95
|
||||
"""Typical Decoding mass. See [Typical Decoding for Natural Language
|
||||
Generation](https://arxiv.org/abs/2202.00666) for more information."""
|
||||
temperature: Optional[float] = 0.8
|
||||
temperature: float | None = 0.8
|
||||
"""The value used to module the logits distribution."""
|
||||
repetition_penalty: Optional[float] = None
|
||||
repetition_penalty: float | None = None
|
||||
"""The parameter for repetition penalty. 1.0 means no penalty.
|
||||
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
|
||||
return_full_text: bool = False
|
||||
"""Whether to prepend the prompt to the generated text"""
|
||||
truncate: Optional[int] = None
|
||||
truncate: int | None = None
|
||||
"""Truncate inputs tokens to the given size"""
|
||||
stop_sequences: list[str] = Field(default_factory=list)
|
||||
"""Stop generating tokens if a member of `stop_sequences` is generated"""
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None
|
||||
"""Random sampling seed"""
|
||||
inference_server_url: str = ""
|
||||
"""text-generation-inference instance base url"""
|
||||
@@ -134,7 +134,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
model: str
|
||||
client: Any = None #: :meta private:
|
||||
async_client: Any = None #: :meta private:
|
||||
task: Optional[str] = None
|
||||
task: str | None = None
|
||||
"""Task to call the model with. Should be a task that returns `generated_text`."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -292,7 +292,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _invocation_params(
|
||||
self, runtime_stop: Optional[list[str]], **kwargs: Any
|
||||
self, runtime_stop: list[str] | None, **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
params = {**self._default_params, **kwargs}
|
||||
params["stop"] = params["stop"] + (runtime_stop or [])
|
||||
@@ -301,8 +301,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to HuggingFace Hub's inference endpoint."""
|
||||
@@ -331,8 +331,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
@@ -361,8 +361,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
@@ -371,13 +371,13 @@ class HuggingFaceEndpoint(LLM):
|
||||
prompt, **invocation_params, stream=True
|
||||
):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
stop_seq_found: str | None = None
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if stop_seq in response:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
text: str | None = None
|
||||
if stop_seq_found:
|
||||
text = response[: response.index(stop_seq_found)]
|
||||
else:
|
||||
@@ -398,8 +398,8 @@ class HuggingFaceEndpoint(LLM):
|
||||
async def _astream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
@@ -407,13 +407,13 @@ class HuggingFaceEndpoint(LLM):
|
||||
prompt, **invocation_params, stream=True
|
||||
):
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
stop_seq_found: str | None = None
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if stop_seq in response:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
text: str | None = None
|
||||
if stop_seq_found:
|
||||
text = response[: response.index(stop_seq_found)]
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations # type: ignore[import-not-found]
|
||||
import importlib.util
|
||||
import logging
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
@@ -71,13 +71,13 @@ class HuggingFacePipeline(BaseLLM):
|
||||
"""
|
||||
|
||||
pipeline: Any = None #: :meta private:
|
||||
model_id: Optional[str] = None
|
||||
model_id: str | None = None
|
||||
"""The model name. If not set explicitly by the user,
|
||||
it will be inferred from the provided pipeline (if available).
|
||||
If neither is provided, the DEFAULT_MODEL_ID will be used."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
model_kwargs: dict | None = None
|
||||
"""Keyword arguments passed to the model."""
|
||||
pipeline_kwargs: Optional[dict] = None
|
||||
pipeline_kwargs: dict | None = None
|
||||
"""Keyword arguments passed to the pipeline."""
|
||||
batch_size: int = DEFAULT_BATCH_SIZE
|
||||
"""Batch size to use when passing multiple documents to generate."""
|
||||
@@ -103,10 +103,10 @@ class HuggingFacePipeline(BaseLLM):
|
||||
model_id: str,
|
||||
task: str,
|
||||
backend: str = "default",
|
||||
device: Optional[int] = None,
|
||||
device_map: Optional[str] = None,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
pipeline_kwargs: Optional[dict] = None,
|
||||
device: int | None = None,
|
||||
device_map: str | None = None,
|
||||
model_kwargs: dict | None = None,
|
||||
pipeline_kwargs: dict | None = None,
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
**kwargs: Any,
|
||||
) -> HuggingFacePipeline:
|
||||
@@ -311,8 +311,8 @@ class HuggingFacePipeline(BaseLLM):
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
# List to hold all results
|
||||
@@ -363,8 +363,8 @@ class HuggingFacePipeline(BaseLLM):
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
from threading import Thread
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import operator as op
|
||||
from typing import Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@@ -44,7 +43,7 @@ _openvino_available = importlib.util.find_spec("openvino") is not None
|
||||
|
||||
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
|
||||
def compare_versions(
|
||||
library_or_version: Union[str, version.Version],
|
||||
library_or_version: str | version.Version,
|
||||
operation: str,
|
||||
requirement_version: str,
|
||||
) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user