mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
partners[lint]: run pyupgrade to get code in line with 3.9 standards (#30781)
Using `pyupgrade` to get all `partners` code up to 3.9 standards (mostly, fixing old `typing` imports).
This commit is contained in:
@@ -7,20 +7,14 @@ import os
|
||||
import re
|
||||
import ssl
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@@ -142,13 +136,13 @@ def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
|
||||
|
||||
|
||||
def _convert_mistral_chat_message_to_message(
|
||||
_message: Dict,
|
||||
_message: dict,
|
||||
) -> BaseMessage:
|
||||
role = _message["role"]
|
||||
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
|
||||
content = cast(str, _message["content"])
|
||||
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
if raw_tool_calls := _message.get("tool_calls"):
|
||||
@@ -196,8 +190,8 @@ async def _araise_on_error(response: httpx.Response) -> None:
|
||||
|
||||
|
||||
async def _aiter_sse(
|
||||
event_source_mgr: AsyncContextManager[EventSource],
|
||||
) -> AsyncIterator[Dict]:
|
||||
event_source_mgr: AbstractAsyncContextManager[EventSource],
|
||||
) -> AsyncIterator[dict]:
|
||||
"""Iterate over the server-sent events."""
|
||||
async with event_source_mgr as event_source:
|
||||
await _araise_on_error(event_source.response)
|
||||
@@ -234,7 +228,7 @@ async def acompletion_with_retry(
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: Dict, default_class: Type[BaseMessageChunk]
|
||||
chunk: dict, default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
_choice = chunk["choices"][0]
|
||||
_delta = _choice["delta"]
|
||||
@@ -243,7 +237,7 @@ def _convert_chunk_to_message_chunk(
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: Dict = {}
|
||||
additional_kwargs: dict = {}
|
||||
response_metadata = {}
|
||||
if raw_tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
@@ -295,7 +289,7 @@ def _convert_chunk_to_message_chunk(
|
||||
|
||||
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
"""Format Langchain ToolCall to dict expected by Mistral."""
|
||||
result: Dict[str, Any] = {
|
||||
result: dict[str, Any] = {
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
@@ -309,7 +303,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
|
||||
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
|
||||
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
|
||||
result: Dict[str, Any] = {
|
||||
result: dict[str, Any] = {
|
||||
"function": {
|
||||
"name": invalid_tool_call["name"],
|
||||
"arguments": invalid_tool_call["args"],
|
||||
@@ -323,13 +317,13 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict: Dict[str, Any] = {"role": "assistant"}
|
||||
message_dict: dict[str, Any] = {"role": "assistant"}
|
||||
tool_calls = []
|
||||
if message.tool_calls or message.invalid_tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
@@ -407,7 +401,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
random_seed: Optional[int] = None
|
||||
safe_mode: Optional[bool] = None
|
||||
streaming: bool = False
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any invocation parameters not explicitly specified."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -417,14 +411,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Any:
|
||||
def build_extra(cls, values: dict[str, Any]) -> Any:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
values = _build_model_kwargs(values, all_required_field_names)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
def _default_params(self) -> dict[str, Any]:
|
||||
"""Get the default parameters for calling the API."""
|
||||
defaults = {
|
||||
"model": self.model,
|
||||
@@ -439,7 +433,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return filtered
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@@ -456,7 +450,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return ls_params
|
||||
|
||||
@property
|
||||
def _client_params(self) -> Dict[str, Any]:
|
||||
def _client_params(self) -> dict[str, Any]:
|
||||
"""Get the parameters used for the client."""
|
||||
return self._default_params
|
||||
|
||||
@@ -473,7 +467,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream = kwargs["stream"]
|
||||
if stream:
|
||||
|
||||
def iter_sse() -> Iterator[Dict]:
|
||||
def iter_sse() -> Iterator[dict]:
|
||||
with connect_sse(
|
||||
self.client, "POST", "/chat/completions", json=kwargs
|
||||
) as event_source:
|
||||
@@ -492,7 +486,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
rtn = _completion_with_retry(**kwargs)
|
||||
return rtn
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
@@ -557,8 +551,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
@@ -577,7 +571,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_chat_result(self, response: Dict) -> ChatResult:
|
||||
def _create_chat_result(self, response: dict) -> ChatResult:
|
||||
generations = []
|
||||
token_usage = response.get("usage", {})
|
||||
for res in response["choices"]:
|
||||
@@ -603,8 +597,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[Dict], Dict[str, Any]]:
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
) -> tuple[list[dict], dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None or "stop" in params:
|
||||
if "stop" in params:
|
||||
@@ -617,15 +611,15 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
@@ -643,15 +637,15 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
@@ -669,8 +663,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
@@ -691,7 +685,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
@@ -733,14 +727,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[Dict, Type]] = None,
|
||||
schema: Optional[Union[dict, type]] = None,
|
||||
*,
|
||||
method: Literal[
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
@@ -1048,7 +1042,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@@ -1058,7 +1052,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return "mistralai-chat"
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"mistral_api_key": "MISTRAL_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
@@ -1067,14 +1061,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
def get_lc_namespace(cls) -> list[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "chat_models", "mistralai"]
|
||||
|
||||
|
||||
def _convert_to_openai_response_format(
|
||||
schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
|
||||
) -> Dict:
|
||||
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
|
||||
) -> dict:
|
||||
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
|
||||
if (
|
||||
isinstance(schema, dict)
|
||||
@@ -1094,8 +1088,10 @@ def _convert_to_openai_response_format(
|
||||
function["schema"] = function.pop("parameters")
|
||||
response_format = {"type": "json_schema", "json_schema": function}
|
||||
|
||||
if strict is not None and strict is not response_format["json_schema"].get(
|
||||
"strict"
|
||||
if (
|
||||
strict is not None
|
||||
and strict is not response_format["json_schema"].get("strict")
|
||||
and isinstance(schema, dict)
|
||||
):
|
||||
msg = (
|
||||
f"Output schema already has 'strict' value set to "
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Iterable, List
|
||||
from collections.abc import Iterable
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
@@ -33,7 +33,7 @@ class DummyTokenizer:
|
||||
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
|
||||
|
||||
@staticmethod
|
||||
def encode_batch(texts: List[str]) -> List[List[str]]:
|
||||
def encode_batch(texts: list[str]) -> list[list[str]]:
|
||||
return [list(text) for text in texts]
|
||||
|
||||
|
||||
@@ -177,7 +177,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
self.tokenizer = Tokenizer.from_pretrained(
|
||||
"mistralai/Mixtral-8x7B-v0.1"
|
||||
)
|
||||
except IOError: # huggingface_hub GatedRepoError
|
||||
except OSError: # huggingface_hub GatedRepoError
|
||||
warnings.warn(
|
||||
"Could not download mistral tokenizer from Huggingface for "
|
||||
"calculating batch sizes. Set a Huggingface token via the "
|
||||
@@ -187,10 +187,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
self.tokenizer = DummyTokenizer()
|
||||
return self
|
||||
|
||||
def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
|
||||
def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
|
||||
"""Split a list of texts into batches of less than 16k tokens
|
||||
for Mistral API."""
|
||||
batch: List[str] = []
|
||||
batch: list[str] = []
|
||||
batch_tokens = 0
|
||||
|
||||
text_token_lengths = [
|
||||
@@ -211,7 +211,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of document texts.
|
||||
|
||||
Args:
|
||||
@@ -230,7 +230,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
wait=wait_fixed(self.wait_time),
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
)
|
||||
def _embed_batch(batch: List[str]) -> Response:
|
||||
def _embed_batch(batch: list[str]) -> Response:
|
||||
response = self.client.post(
|
||||
url="/embeddings",
|
||||
json=dict(
|
||||
@@ -252,7 +252,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
raise
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of document texts.
|
||||
|
||||
Args:
|
||||
@@ -283,7 +283,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
raise
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
@@ -294,7 +294,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Embed a single query text.
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user