partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)

Using `pyupgrade` to get all `partners` code up to 3.9 standards
(mostly, fixing old `typing` imports).
Author: Sydney Runkle
Date: 2025-04-11 07:18:44 -04:00
Committed by: GitHub
Parent: e72f3c26a0
Commit: 8c6734325b

123 changed files with 1000 additions and 1109 deletions
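
For orientation, here is a minimal before/after sketch of the kind of rewrite pyupgrade's `--py39-plus` mode performs throughout the diffs below. The function is a hypothetical stand-in, not code from this commit:

# Hypothetical stand-in illustrating the pyupgrade rewrite.

# Before (pre-3.9 style): generic aliases imported from typing.
from typing import Dict, List, Optional

def count_tokens_old(texts: List[str]) -> Dict[str, int]:
    return {t: len(t.split()) for t in texts}

# After: PEP 585 builtin generics. Optional and Union survive because
# the `X | None` spelling (PEP 604) requires Python 3.10, while this
# commit targets 3.9.
from typing import Optional

def count_tokens(texts: list[str], stop: Optional[list[str]] = None) -> dict[str, int]:
    return {t: len(t.split()) for t in texts if not (stop and t in stop)}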


@@ -1,16 +1,13 @@
 """Hugging Face Chat Wrapper."""
 
 import json
+from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
-    Dict,
-    List,
     Literal,
     Optional,
-    Sequence,
-    Type,
     Union,
     cast,
 )
@@ -46,8 +43,8 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."
 class TGI_RESPONSE:
     """Response from the TextGenInference API."""
 
-    choices: List[Any]
-    usage: Dict
+    choices: list[Any]
+    usage: dict
 
 
 @dataclass
@@ -56,12 +53,12 @@ class TGI_MESSAGE:
     role: str
     content: str
-    tool_calls: List[Dict]
+    tool_calls: list[dict]
 
 
 def _convert_message_to_chat_message(
     message: BaseMessage,
-) -> Dict:
+) -> dict:
     if isinstance(message, ChatMessage):
         return dict(role=message.role, content=message.content)
     elif isinstance(message, HumanMessage):
@@ -104,7 +101,7 @@ def _convert_TGI_message_to_LC_message(
     content = cast(str, _message.content)
     if content is None:
         content = ""
-    additional_kwargs: Dict = {}
+    additional_kwargs: dict = {}
     if tool_calls := _message.tool_calls:
         if "arguments" in tool_calls[0]["function"]:
             functions = tool_calls[0]["function"].pop("arguments")
@@ -358,8 +355,8 @@ class ChatHuggingFace(BaseChatModel):
 
     def _generate(
         self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
+        messages: list[BaseMessage],
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
@@ -380,8 +377,8 @@ class ChatHuggingFace(BaseChatModel):
 
     async def _agenerate(
        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
+        messages: list[BaseMessage],
+        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
@@ -398,7 +395,7 @@ class ChatHuggingFace(BaseChatModel):
 
     def _to_chat_prompt(
         self,
-        messages: List[BaseMessage],
+        messages: list[BaseMessage],
     ) -> str:
         """Convert a list of messages into a prompt format expected by wrapped LLM."""
         if not messages:
@@ -472,7 +469,7 @@ class ChatHuggingFace(BaseChatModel):
 
     def bind_tools(
         self,
-        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+        tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
         *,
         tool_choice: Optional[
             Union[dict, str, Literal["auto", "none", "required"], bool]
@@ -529,8 +526,8 @@ class ChatHuggingFace(BaseChatModel):
         return super().bind(tools=formatted_tools, **kwargs)
 
     def _create_message_dicts(
-        self, messages: List[BaseMessage], stop: Optional[List[str]]
-    ) -> List[Dict[Any, Any]]:
+        self, messages: list[BaseMessage], stop: Optional[list[str]]
+    ) -> list[dict[Any, Any]]:
         message_dicts = [_convert_message_to_chat_message(m) for m in messages]
         return message_dicts
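
The rewritten annotations are runtime-equivalent on Python 3.9+. A minimal, standard-library-only sketch verifying that, and showing why the `Sequence` import moved to `collections.abc`:

# PEP 585 builtin generics behave like the old typing aliases at runtime.
from collections.abc import Sequence
from typing import get_args, get_origin

hint = dict[str, list[float]]  # was Dict[str, List[float]]
assert get_origin(hint) is dict
assert get_args(hint) == (str, list[float])

# collections.abc.Sequence is subscriptable directly since 3.9, so the
# typing.Sequence import in the first hunk above became unnecessary.
assert get_origin(Sequence[int]) is Sequence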


@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from langchain_core.embeddings import Embeddings
 from pydantic import BaseModel, ConfigDict, Field
@@ -40,16 +40,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     cache_folder: Optional[str] = None
     """Path to store models.
     Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
-    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
     `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
-    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the documents of
     the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
     `precision`, `normalize_embeddings`, and more.
     See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
-    query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Keyword arguments to pass when calling the `encode` method for the query of
     the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
     `precision`, `normalize_embeddings`, and more.
@@ -102,8 +102,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
     )
 
     def _embed(
-        self, texts: list[str], encode_kwargs: Dict[str, Any]
-    ) -> List[List[float]]:
+        self, texts: list[str], encode_kwargs: dict[str, Any]
+    ) -> list[list[float]]:
         """
         Embed a text using the HuggingFace transformer model.
@@ -138,7 +138,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         return embeddings.tolist()
 
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
 
         Args:
@@ -149,7 +149,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         """
         return self._embed(texts, self.encode_kwargs)
 
-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         """Compute query embeddings using a HuggingFace transformer model.
 
         Args:
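
A hedged usage sketch of the class above; the model name is only an example and the import path assumes the langchain_huggingface partner package, but the method signatures match the diff:

from langchain_huggingface import HuggingFaceEmbeddings  # assumed import path

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # example model
    encode_kwargs={"normalize_embeddings": True},  # documented above
)

doc_vecs: list[list[float]] = embeddings.embed_documents(["hello", "world"])
query_vec: list[float] = embeddings.embed_query("hello")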


@@ -1,5 +1,5 @@
 import os
-from typing import Any, List, Optional
+from typing import Any, Optional
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import from_env
@@ -101,7 +101,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         )
         return self
 
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call out to HuggingFaceHub's embedding endpoint for embedding search docs.
 
         Args:
@@ -117,7 +117,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         responses = self.client.feature_extraction(text=texts, **_model_kwargs)
         return responses
 
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Async Call to HuggingFaceHub's embedding endpoint for embedding search docs.
 
         Args:
@@ -134,7 +134,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         )
         return responses
 
-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         """Call out to HuggingFaceHub's embedding endpoint for embedding query text.
 
         Args:
@@ -146,7 +146,7 @@ class HuggingFaceEndpointEmbeddings(BaseModel, Embeddings):
         response = self.embed_documents([text])[0]
         return response
 
-    async def aembed_query(self, text: str) -> List[float]:
+    async def aembed_query(self, text: str) -> list[float]:
         """Async Call to HuggingFaceHub's embedding endpoint for embedding query text.
 
         Args:
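
Worth noting in the hunks above: embed_query delegates to embed_documents([text])[0], and the async variants mirror the same shape. A self-contained sketch of that pattern (class and embedding values are hypothetical):

class TinyEmbeddings:
    """Hypothetical stand-in for the delegation pattern above."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [[float(len(t))] for t in texts]  # stand-in embedding

    def embed_query(self, text: str) -> list[float]:
        # Same shape as HuggingFaceEndpointEmbeddings.embed_query:
        # reuse the batch path on a one-element list, take the first row.
        return self.embed_documents([text])[0]

assert TinyEmbeddings().embed_query("hello") == [5.0]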


@@ -2,7 +2,8 @@ import inspect
 import json  # type: ignore[import-not-found]
 import logging
 import os
-from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
+from collections.abc import AsyncIterator, Iterator, Mapping
+from typing import Any, Optional
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -96,7 +97,7 @@ class HuggingFaceEndpoint(LLM):
     """Whether to prepend the prompt to the generated text"""
     truncate: Optional[int] = None
     """Truncate inputs tokens to the given size"""
-    stop_sequences: List[str] = Field(default_factory=list)
+    stop_sequences: list[str] = Field(default_factory=list)
     """Stop generating tokens if a member of `stop_sequences` is generated"""
     seed: Optional[int] = None
     """Random sampling seed"""
@@ -111,9 +112,9 @@ class HuggingFaceEndpoint(LLM):
     watermark: bool = False
     """Watermarking with [A Watermark for Large Language Models]
     (https://arxiv.org/abs/2301.10226)"""
-    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    server_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any text-generation-inference server parameters not explicitly specified"""
-    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `call` not explicitly specified"""
     model: str
     client: Any = None  #: :meta private:
@@ -128,7 +129,7 @@ class HuggingFaceEndpoint(LLM):
 
     @model_validator(mode="before")
     @classmethod
-    def build_extra(cls, values: Dict[str, Any]) -> Any:
+    def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
         extra = values.get("model_kwargs", {})
@@ -252,7 +253,7 @@ class HuggingFaceEndpoint(LLM):
         return self
 
     @property
-    def _default_params(self) -> Dict[str, Any]:
+    def _default_params(self) -> dict[str, Any]:
         """Get the default parameters for calling text generation inference API."""
         return {
             "max_new_tokens": self.max_new_tokens,
@@ -285,8 +286,8 @@ class HuggingFaceEndpoint(LLM):
         return "huggingface_endpoint"
 
     def _invocation_params(
-        self, runtime_stop: Optional[List[str]], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, runtime_stop: Optional[list[str]], **kwargs: Any
+    ) -> dict[str, Any]:
         params = {**self._default_params, **kwargs}
         params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
         return params
@@ -294,7 +295,7 @@ class HuggingFaceEndpoint(LLM):
     def _call(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
@@ -326,7 +327,7 @@ class HuggingFaceEndpoint(LLM):
     async def _acall(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
@@ -357,7 +358,7 @@ class HuggingFaceEndpoint(LLM):
     def _stream(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
@@ -394,7 +395,7 @@ class HuggingFaceEndpoint(LLM):
     async def _astream(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
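
The _invocation_params hunk above shows a merge pattern worth calling out: runtime stop values are appended to the configured stop_sequences rather than replacing them. A self-contained sketch (the helper name is hypothetical):

from typing import Any, Optional

def merge_invocation_params(
    default_params: dict[str, Any],
    runtime_stop: Optional[list[str]],
    **kwargs: Any,
) -> dict[str, Any]:
    # kwargs override defaults; runtime stop sequences are appended.
    params = {**default_params, **kwargs}
    params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
    return params

params = merge_invocation_params(
    {"max_new_tokens": 512, "stop_sequences": ["</s>"]},
    runtime_stop=["Observation:"],
)
assert params["stop_sequences"] == ["</s>", "Observation:"]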


@@ -2,7 +2,8 @@ from __future__ import annotations  # type: ignore[import-not-found]
 
 import importlib.util
 import logging
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from collections.abc import Iterator, Mapping
+from typing import Any, Optional
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import BaseLLM
@@ -82,7 +83,7 @@ class HuggingFacePipeline(BaseLLM):
 
     @model_validator(mode="before")
     @classmethod
-    def pre_init_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+    def pre_init_validator(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Ensure model_id is set either by pipeline or user input."""
         if "model_id" not in values:
             if "pipeline" in values and values["pipeline"]:
@@ -297,13 +298,13 @@ class HuggingFacePipeline(BaseLLM):
     def _generate(
         self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
+        prompts: list[str],
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
         # List to hold all results
-        text_generations: List[str] = []
+        text_generations: list[str] = []
 
         pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
         skip_prompt = kwargs.get("skip_prompt", False)
@@ -347,7 +348,7 @@ class HuggingFacePipeline(BaseLLM):
     def _stream(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
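
One wrinkle in this last file: it keeps `from __future__ import annotations` (first hunk), under which annotations are stored as strings and never evaluated at definition time, so builtin generics are safe there regardless of interpreter version. A minimal sketch with a hypothetical function:

from __future__ import annotations

def generate(prompts: list[str], stop: list[str] | None = None) -> list[str]:
    # `list[str] | None` parses here only because the future import
    # defers evaluation; evaluating `|` in annotations needs Python 3.10.
    return list(prompts)

# Annotations are stored as plain strings:
print(generate.__annotations__)
# {'prompts': 'list[str]', 'stop': 'list[str] | None', 'return': 'list[str]'}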