partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)

Using `pyupgrade` to get all `partners` code up to 3.9 standards
(mostly, fixing old `typing` imports).
This commit is contained in:
Sydney Runkle
2025-04-11 07:18:44 -04:00
committed by GitHub
parent e72f3c26a0
commit 8c6734325b
123 changed files with 1000 additions and 1109 deletions

View File

@@ -7,20 +7,14 @@ import os
import re
import ssl
import uuid
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
from operator import itemgetter
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -142,13 +136,13 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
def _convert_mistral_chat_message_to_message(
_message: Dict,
_message: dict,
) -> BaseMessage:
role = _message["role"]
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
content = cast(str, _message["content"])
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _message.get("tool_calls"):
@@ -196,8 +190,8 @@ async def _araise_on_error(response: httpx.Response) -> None:
async def _aiter_sse(
event_source_mgr: AsyncContextManager[EventSource],
) -> AsyncIterator[Dict]:
event_source_mgr: AbstractAsyncContextManager[EventSource],
) -> AsyncIterator[dict]:
"""Iterate over the server-sent events."""
async with event_source_mgr as event_source:
await _araise_on_error(event_source.response)
@@ -234,7 +228,7 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk(
chunk: Dict, default_class: Type[BaseMessageChunk]
chunk: dict, default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
_choice = chunk["choices"][0]
_delta = _choice["delta"]
@@ -243,7 +237,7 @@ def _convert_chunk_to_message_chunk(
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {}
additional_kwargs: dict = {}
response_metadata = {}
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
@@ -295,7 +289,7 @@ def _convert_chunk_to_message_chunk(
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
"""Format Langchain ToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
result: dict[str, Any] = {
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
@@ -309,7 +303,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
result: dict[str, Any] = {
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
@@ -323,13 +317,13 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> Dict:
) -> dict:
if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
message_dict: Dict[str, Any] = {"role": "assistant"}
message_dict: dict[str, Any] = {"role": "assistant"}
tool_calls = []
if message.tool_calls or message.invalid_tool_calls:
for tool_call in message.tool_calls:
@@ -407,7 +401,7 @@ class ChatMistralAI(BaseChatModel):
random_seed: Optional[int] = None
safe_mode: Optional[bool] = None
streaming: bool = False
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any invocation parameters not explicitly specified."""
model_config = ConfigDict(
@@ -417,14 +411,14 @@ class ChatMistralAI(BaseChatModel):
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
@property
def _default_params(self) -> Dict[str, Any]:
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling the API."""
defaults = {
"model": self.model,
@@ -439,7 +433,7 @@ class ChatMistralAI(BaseChatModel):
return filtered
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
self, stop: Optional[list[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
@@ -456,7 +450,7 @@ class ChatMistralAI(BaseChatModel):
return ls_params
@property
def _client_params(self) -> Dict[str, Any]:
def _client_params(self) -> dict[str, Any]:
"""Get the parameters used for the client."""
return self._default_params
@@ -473,7 +467,7 @@ class ChatMistralAI(BaseChatModel):
stream = kwargs["stream"]
if stream:
def iter_sse() -> Iterator[Dict]:
def iter_sse() -> Iterator[dict]:
with connect_sse(
self.client, "POST", "/chat/completions", json=kwargs
) as event_source:
@@ -492,7 +486,7 @@ class ChatMistralAI(BaseChatModel):
rtn = _completion_with_retry(**kwargs)
return rtn
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
@@ -557,8 +551,8 @@ class ChatMistralAI(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@@ -577,7 +571,7 @@ class ChatMistralAI(BaseChatModel):
)
return self._create_chat_result(response)
def _create_chat_result(self, response: Dict) -> ChatResult:
def _create_chat_result(self, response: dict) -> ChatResult:
generations = []
token_usage = response.get("usage", {})
for res in response["choices"]:
@@ -603,8 +597,8 @@ class ChatMistralAI(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict], Dict[str, Any]]:
self, messages: list[BaseMessage], stop: Optional[list[str]]
) -> tuple[list[dict], dict[str, Any]]:
params = self._client_params
if stop is not None or "stop" in params:
if "stop" in params:
@@ -617,15 +611,15 @@ class ChatMistralAI(BaseChatModel):
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
@@ -643,15 +637,15 @@ class ChatMistralAI(BaseChatModel):
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
@@ -669,8 +663,8 @@ class ChatMistralAI(BaseChatModel):
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
@@ -691,7 +685,7 @@ class ChatMistralAI(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
@@ -733,14 +727,14 @@ class ChatMistralAI(BaseChatModel):
def with_structured_output(
self,
schema: Optional[Union[Dict, Type]] = None,
schema: Optional[Union[dict, type]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
@@ -1048,7 +1042,7 @@ class ChatMistralAI(BaseChatModel):
return llm | output_parser
@property
def _identifying_params(self) -> Dict[str, Any]:
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@@ -1058,7 +1052,7 @@ class ChatMistralAI(BaseChatModel):
return "mistralai-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
def lc_secrets(self) -> dict[str, str]:
return {"mistral_api_key": "MISTRAL_API_KEY"}
@classmethod
@@ -1067,14 +1061,14 @@ class ChatMistralAI(BaseChatModel):
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "mistralai"]
def _convert_to_openai_response_format(
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
) -> Dict:
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
) -> dict:
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
if (
isinstance(schema, dict)
@@ -1094,8 +1088,10 @@ def _convert_to_openai_response_format(
function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function}
if strict is not None and strict is not response_format["json_schema"].get(
"strict"
if (
strict is not None
and strict is not response_format["json_schema"].get("strict")
and isinstance(schema, dict)
):
msg = (
f"Output schema already has 'strict' value set to "

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
from typing import Iterable, List
from collections.abc import Iterable
import httpx
from httpx import Response
@@ -33,7 +33,7 @@ class DummyTokenizer:
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
@staticmethod
def encode_batch(texts: List[str]) -> List[List[str]]:
def encode_batch(texts: list[str]) -> list[list[str]]:
return [list(text) for text in texts]
@@ -177,7 +177,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = Tokenizer.from_pretrained(
"mistralai/Mixtral-8x7B-v0.1"
)
except IOError: # huggingface_hub GatedRepoError
except OSError: # huggingface_hub GatedRepoError
warnings.warn(
"Could not download mistral tokenizer from Huggingface for "
"calculating batch sizes. Set a Huggingface token via the "
@@ -187,10 +187,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
self.tokenizer = DummyTokenizer()
return self
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
"""Split a list of texts into batches of less than 16k tokens
for Mistral API."""
batch: List[str] = []
batch: list[str] = []
batch_tokens = 0
text_token_lengths = [
@@ -211,7 +211,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
if batch:
yield batch
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts.
Args:
@@ -230,7 +230,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
def _embed_batch(batch: List[str]) -> Response:
def _embed_batch(batch: list[str]) -> Response:
response = self.client.post(
url="/embeddings",
json=dict(
@@ -252,7 +252,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
logger.error(f"An error occurred with MistralAI: {e}")
raise
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts.
Args:
@@ -283,7 +283,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
logger.error(f"An error occurred with MistralAI: {e}")
raise
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed a single query text.
Args:
@@ -294,7 +294,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
"""
return self.embed_documents([text])[0]
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Embed a single query text.
Args: