mistralai[patch]: ruff fixes and rules (#31918)

* bump ruff deps
* add more thorough ruff rules
* fix violations of said rules
This commit is contained in:
Mason Daugherty
2025-07-08 12:44:42 -04:00
committed by GitHub
parent ae210c1590
commit cbb418b4bf
10 changed files with 214 additions and 143 deletions

View File

@@ -94,8 +94,7 @@ def _create_retry_decorator(
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
"""Return a tenacity retry decorator, preconfigured to handle exceptions."""
errors = [httpx.RequestError, httpx.StreamError]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
@@ -103,12 +102,12 @@ def _create_retry_decorator(
def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
"""Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9"""
"""Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9."""
return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))
def _base62_encode(num: int) -> str:
"""Encodes a number in base62 and ensures result is of a specified length."""
"""Encode a number in base62 and ensure the result is of a specified length."""
base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
if num == 0:
return base62[0]
@@ -122,17 +121,15 @@ def _base62_encode(num: int) -> str:
def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
"""Convert a tool call ID to a Mistral-compatible format"""
"""Convert a tool call ID to a Mistral-compatible format."""
if _is_valid_mistral_tool_call_id(tool_call_id):
return tool_call_id
else:
hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
hash_int = int.from_bytes(hash_bytes, byteorder="big")
base62_str = _base62_encode(hash_int)
if len(base62_str) >= 9:
return base62_str[:9]
else:
return base62_str.rjust(9, "0")
hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
hash_int = int.from_bytes(hash_bytes, byteorder="big")
base62_str = _base62_encode(hash_int)
if len(base62_str) >= 9:
return base62_str[:9]
return base62_str.rjust(9, "0")
def _convert_mistral_chat_message_to_message(
@@ -140,7 +137,8 @@ def _convert_mistral_chat_message_to_message(
) -> BaseMessage:
role = _message["role"]
if role != "assistant":
raise ValueError(f"Expected role to be 'assistant', got {role}")
msg = f"Expected role to be 'assistant', got {role}"
raise ValueError(msg)
content = cast(str, _message["content"])
additional_kwargs: dict = {}
@@ -170,9 +168,12 @@ def _raise_on_error(response: httpx.Response) -> None:
"""Raise an error if the response is an error."""
if httpx.codes.is_error(response.status_code):
error_message = response.read().decode("utf-8")
raise httpx.HTTPStatusError(
msg = (
f"Error response {response.status_code} "
f"while fetching {response.url}: {error_message}",
f"while fetching {response.url}: {error_message}"
)
raise httpx.HTTPStatusError(
msg,
request=response.request,
response=response,
)
@@ -182,9 +183,12 @@ async def _araise_on_error(response: httpx.Response) -> None:
"""Raise an error if the response is an error."""
if httpx.codes.is_error(response.status_code):
error_message = (await response.aread()).decode("utf-8")
raise httpx.HTTPStatusError(
msg = (
f"Error response {response.status_code} "
f"while fetching {response.url}: {error_message}",
f"while fetching {response.url}: {error_message}"
)
raise httpx.HTTPStatusError(
msg,
request=response.request,
response=response,
)
@@ -220,10 +224,9 @@ async def acompletion_with_retry(
llm.async_client, "POST", "/chat/completions", json=kwargs
)
return _aiter_sse(event_source)
else:
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
await _araise_on_error(response)
return response.json()
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
await _araise_on_error(response)
return response.json()
return await _completion_with_retry(**kwargs)
@@ -237,7 +240,7 @@ def _convert_chunk_to_message_chunk(
content = _delta.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: dict = {}
response_metadata = {}
if raw_tool_calls := _delta.get("tool_calls"):
@@ -281,12 +284,11 @@ def _convert_chunk_to_message_chunk(
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata=response_metadata,
)
elif role == "system" or default_class == SystemMessageChunk:
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content) # type: ignore[call-arg]
return default_class(content=content) # type: ignore[call-arg]
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
@@ -321,18 +323,24 @@ def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> dict:
if isinstance(message, ChatMessage):
return dict(role=message.role, content=message.content)
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
return {"role": message.role, "content": message.content}
if isinstance(message, HumanMessage):
return {"role": "user", "content": message.content}
if isinstance(message, AIMessage):
message_dict: dict[str, Any] = {"role": "assistant"}
tool_calls = []
if message.tool_calls or message.invalid_tool_calls:
for tool_call in message.tool_calls:
tool_calls.append(_format_tool_call_for_mistral(tool_call))
tool_calls.extend(
[
_format_tool_call_for_mistral(tool_call)
for tool_call in message.tool_calls
]
)
for invalid_tool_call in message.invalid_tool_calls:
tool_calls.append(
tool_calls.extend(
_format_invalid_tool_call_for_mistral(invalid_tool_call)
for invalid_tool_call in message.invalid_tool_calls
)
elif "tool_calls" in message.additional_kwargs:
for tc in message.additional_kwargs["tool_calls"]:
@@ -359,9 +367,9 @@ def _convert_message_to_mistral_chat_message(
if "prefix" in message.additional_kwargs:
message_dict["prefix"] = message.additional_kwargs["prefix"]
return message_dict
elif isinstance(message, SystemMessage):
return dict(role="system", content=message.content)
elif isinstance(message, ToolMessage):
if isinstance(message, SystemMessage):
return {"role": "system", "content": message.content}
if isinstance(message, ToolMessage):
return {
"role": "tool",
"content": message.content,
@@ -370,8 +378,8 @@ def _convert_message_to_mistral_chat_message(
message.tool_call_id
),
}
else:
raise ValueError(f"Got unknown type {message}")
msg = f"Got unknown type {message}"
raise ValueError(msg)
class ChatMistralAI(BaseChatModel):
@@ -380,10 +388,10 @@ class ChatMistralAI(BaseChatModel):
# The type for client and async_client is ignored because the type is not
# an Optional after the model is initialized and the model_validator
# is run.
client: httpx.Client = Field( # type: ignore # : meta private:
client: httpx.Client = Field( # type: ignore[assignment] # : meta private:
default=None, exclude=True
)
async_client: httpx.AsyncClient = Field( # type: ignore # : meta private:
async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # : meta private:
default=None, exclude=True
) #: :meta private:
mistral_api_key: Optional[SecretStr] = Field(
@@ -417,8 +425,7 @@ class ChatMistralAI(BaseChatModel):
def build_extra(cls, values: dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
values = _build_model_kwargs(values, all_required_field_names)
return values
return _build_model_kwargs(values, all_required_field_names)
@property
def _default_params(self) -> dict[str, Any]:
@@ -432,8 +439,7 @@ class ChatMistralAI(BaseChatModel):
"safe_prompt": self.safe_mode,
**self.model_kwargs,
}
filtered = {k: v for k, v in defaults.items() if v is not None}
return filtered
return {k: v for k, v in defaults.items() if v is not None}
def _get_ls_params(
self, stop: Optional[list[str]] = None, **kwargs: Any
@@ -481,13 +487,11 @@ class ChatMistralAI(BaseChatModel):
yield event.json()
return iter_sse()
else:
response = self.client.post(url="/chat/completions", json=kwargs)
_raise_on_error(response)
return response.json()
response = self.client.post(url="/chat/completions", json=kwargs)
_raise_on_error(response)
return response.json()
rtn = _completion_with_retry(**kwargs)
return rtn
return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
@@ -502,8 +506,7 @@ class ChatMistralAI(BaseChatModel):
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
return {"token_usage": overall_token_usage, "model_name": self.model}
@model_validator(mode="after")
def validate_environment(self) -> Self:
@@ -545,10 +548,12 @@ class ChatMistralAI(BaseChatModel):
)
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
msg = "temperature must be in the range [0.0, 1.0]"
raise ValueError(msg)
if self.top_p is not None and not 0 <= self.top_p <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
msg = "top_p must be in the range [0.0, 1.0]"
raise ValueError(msg)
return self
@@ -557,7 +562,7 @@ class ChatMistralAI(BaseChatModel):
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
stream: Optional[bool] = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
@@ -669,7 +674,7 @@ class ChatMistralAI(BaseChatModel):
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
stream: Optional[bool] = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
@@ -689,7 +694,7 @@ class ChatMistralAI(BaseChatModel):
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, # noqa: PYI051
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
@@ -707,15 +712,15 @@ class ChatMistralAI(BaseChatModel):
{"type": "function", "function": {"name": <<tool_name>>}}.
kwargs: Any additional parameters are passed directly to
``self.bind(**kwargs)``.
"""
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
tool_names = []
for tool in formatted_tools:
if "function" in tool and (name := tool["function"].get("name")):
tool_names.append(name)
elif name := tool.get("name"):
if ("function" in tool and (name := tool["function"].get("name"))) or (
name := tool.get("name")
):
tool_names.append(name)
else:
pass
@@ -738,7 +743,7 @@ class ChatMistralAI(BaseChatModel):
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
r"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
@@ -785,6 +790,12 @@ class ChatMistralAI(BaseChatModel):
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs: Any additional parameters are passed directly to
``self.bind(**kwargs)``. This is useful for passing in
parameters such as ``tool_choice`` or ``tools`` to control
which tool the model should call, or to pass in parameters such as
``stop`` to control when the model should stop generating output.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@@ -968,14 +979,16 @@ class ChatMistralAI(BaseChatModel):
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is 'function_calling'. "
"Received None."
)
raise ValueError(msg)
# TODO: Update to pass in tool name as tool_choice if/when Mistral supports
# specifying a tool.
llm = self.bind_tools(
@@ -1014,10 +1027,11 @@ class ChatMistralAI(BaseChatModel):
)
elif method == "json_schema":
if schema is None:
raise ValueError(
msg = (
"schema must be specified when method is 'json_schema'. "
"Received None."
)
raise ValueError(msg)
response_format = _convert_to_openai_response_format(schema, strict=True)
llm = self.bind(
response_format=response_format,
@@ -1041,8 +1055,7 @@ class ChatMistralAI(BaseChatModel):
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
return llm | output_parser
@property
def _identifying_params(self) -> dict[str, Any]:
@@ -1072,7 +1085,7 @@ class ChatMistralAI(BaseChatModel):
def _convert_to_openai_response_format(
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
) -> dict:
"""Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
"""Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""
if (
isinstance(schema, dict)
and "json_schema" in schema

View File

@@ -17,20 +17,20 @@ from pydantic import (
model_validator,
)
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tokenizers import Tokenizer # type: ignore
from tokenizers import Tokenizer # type: ignore[import]
from typing_extensions import Self
logger = logging.getLogger(__name__)
MAX_TOKENS = 16_000
"""A batching parameter for the Mistral API. This is NOT the maximum number of tokens
accepted by the embedding model for each document/chunk, but rather the maximum number
accepted by the embedding model for each document/chunk, but rather the maximum number
of tokens that can be sent in a single request to the Mistral API (across multiple
documents/chunks)"""
class DummyTokenizer:
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""
@staticmethod
def encode_batch(texts: list[str]) -> list[list[str]]:
@@ -126,9 +126,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
# The type for client and async_client is ignored because the type is not
# an Optional after the model is initialized and the model_validator
# is run.
client: httpx.Client = Field(default=None) # type: ignore # : :meta private:
client: httpx.Client = Field(default=None) # type: ignore[assignment] # :meta private:
async_client: httpx.AsyncClient = Field( # type: ignore # : meta private:
async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # :meta private:
default=None
)
mistral_api_key: SecretStr = Field(
@@ -153,7 +153,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate configuration."""
api_key_str = self.mistral_api_key.get_secret_value()
# todo: handle retries
if not self.client:
@@ -187,14 +186,14 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
"Could not download mistral tokenizer from Huggingface for "
"calculating batch sizes. Set a Huggingface token via the "
"HF_TOKEN environment variable to download the real tokenizer. "
"Falling back to a dummy tokenizer that uses `len()`."
"Falling back to a dummy tokenizer that uses `len()`.",
stacklevel=2,
)
self.tokenizer = DummyTokenizer()
return self
def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
"""Split a list of texts into batches of less than 16k tokens for Mistral
API."""
"""Split list of texts into batches of less than 16k tokens for Mistral API."""
batch: list[str] = []
batch_tokens = 0
@@ -224,6 +223,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
try:
batch_responses = []
@@ -238,16 +238,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
def _embed_batch(batch: list[str]) -> Response:
response = self.client.post(
url="/embeddings",
json=dict(
model=self.model,
input=batch,
),
json={
"model": self.model,
"input": batch,
},
)
response.raise_for_status()
return response
for batch in self._get_batches(texts):
batch_responses.append(_embed_batch(batch))
batch_responses = [
_embed_batch(batch) for batch in self._get_batches(texts)
]
return [
list(map(float, embedding_obj["embedding"]))
for response in batch_responses
@@ -265,16 +266,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
try:
batch_responses = await asyncio.gather(
*[
self.async_client.post(
url="/embeddings",
json=dict(
model=self.model,
input=batch,
),
json={
"model": self.model,
"input": batch,
},
)
for batch in self._get_batches(texts)
]
@@ -296,6 +298,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
return self.embed_documents([text])[0]
@@ -307,5 +310,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
Returns:
Embedding for the text.
"""
return (await self.aembed_documents([text]))[0]