From cbb418b4bf2d35c7fc99323bcf2e5642f6a3b8cd Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Tue, 8 Jul 2025 12:44:42 -0400 Subject: [PATCH] mistralai[patch]: ruff fixes and rules (#31918) * bump ruff deps * add more thorough ruff rules * fix said rules --- .../langchain_mistralai/chat_models.py | 149 ++++++++++-------- .../langchain_mistralai/embeddings.py | 42 ++--- libs/partners/mistralai/pyproject.toml | 58 ++++++- .../integration_tests/test_chat_models.py | 23 +-- .../tests/integration_tests/test_compile.py | 1 - .../integration_tests/test_embeddings.py | 2 +- .../tests/integration_tests/test_standard.py | 2 +- .../tests/unit_tests/test_chat_models.py | 36 ++--- .../tests/unit_tests/test_standard.py | 2 +- libs/partners/mistralai/uv.lock | 42 ++--- 10 files changed, 214 insertions(+), 143 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index cdbc5d37bb8..358f4ea73a4 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -94,8 +94,7 @@ def _create_retry_decorator( Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] ] = None, ) -> Callable[[Any], Any]: - """Returns a tenacity retry decorator, preconfigured to handle exceptions""" - + """Return a tenacity retry decorator, preconfigured to handle exceptions.""" errors = [httpx.RequestError, httpx.StreamError] return create_base_retry_decorator( error_types=errors, max_retries=llm.max_retries, run_manager=run_manager @@ -103,12 +102,12 @@ def _create_retry_decorator( def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool: - """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9""" + """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9.""" return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id)) def _base62_encode(num: int) -> str: - """Encodes a number in base62 and ensures result is of a specified length.""" + """Encode a number in base62 and ensures result is of a specified length.""" base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" if num == 0: return base62[0] @@ -122,17 +121,15 @@ def _base62_encode(num: int) -> str: def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str: - """Convert a tool call ID to a Mistral-compatible format""" + """Convert a tool call ID to a Mistral-compatible format.""" if _is_valid_mistral_tool_call_id(tool_call_id): return tool_call_id - else: - hash_bytes = hashlib.sha256(tool_call_id.encode()).digest() - hash_int = int.from_bytes(hash_bytes, byteorder="big") - base62_str = _base62_encode(hash_int) - if len(base62_str) >= 9: - return base62_str[:9] - else: - return base62_str.rjust(9, "0") + hash_bytes = hashlib.sha256(tool_call_id.encode()).digest() + hash_int = int.from_bytes(hash_bytes, byteorder="big") + base62_str = _base62_encode(hash_int) + if len(base62_str) >= 9: + return base62_str[:9] + return base62_str.rjust(9, "0") def _convert_mistral_chat_message_to_message( @@ -140,7 +137,8 @@ def _convert_mistral_chat_message_to_message( ) -> BaseMessage: role = _message["role"] if role != "assistant": - raise ValueError(f"Expected role to be 'assistant', got {role}") + msg = f"Expected role to be 'assistant', got {role}" + raise ValueError(msg) content = cast(str, _message["content"]) additional_kwargs: dict = {} @@ -170,9 +168,12 @@ def _raise_on_error(response: httpx.Response) -> None: """Raise an error if the response is an error.""" if httpx.codes.is_error(response.status_code): error_message = response.read().decode("utf-8") - raise httpx.HTTPStatusError( + msg = ( f"Error response {response.status_code} " - f"while fetching {response.url}: {error_message}", + f"while fetching {response.url}: {error_message}" + ) + raise httpx.HTTPStatusError( + msg, request=response.request, response=response, ) @@ -182,9 +183,12 @@ async def _araise_on_error(response: httpx.Response) -> None: """Raise an error if the response is an error.""" if httpx.codes.is_error(response.status_code): error_message = (await response.aread()).decode("utf-8") - raise httpx.HTTPStatusError( + msg = ( f"Error response {response.status_code} " - f"while fetching {response.url}: {error_message}", + f"while fetching {response.url}: {error_message}" + ) + raise httpx.HTTPStatusError( + msg, request=response.request, response=response, ) @@ -220,10 +224,9 @@ async def acompletion_with_retry( llm.async_client, "POST", "/chat/completions", json=kwargs ) return _aiter_sse(event_source) - else: - response = await llm.async_client.post(url="/chat/completions", json=kwargs) - await _araise_on_error(response) - return response.json() + response = await llm.async_client.post(url="/chat/completions", json=kwargs) + await _araise_on_error(response) + return response.json() return await _completion_with_retry(**kwargs) @@ -237,7 +240,7 @@ def _convert_chunk_to_message_chunk( content = _delta.get("content") or "" if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: + if role == "assistant" or default_class == AIMessageChunk: additional_kwargs: dict = {} response_metadata = {} if raw_tool_calls := _delta.get("tool_calls"): @@ -281,12 +284,11 @@ def _convert_chunk_to_message_chunk( usage_metadata=usage_metadata, # type: ignore[arg-type] response_metadata=response_metadata, ) - elif role == "system" or default_class == SystemMessageChunk: + if role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) - elif role or default_class == ChatMessageChunk: + if role or default_class == ChatMessageChunk: return ChatMessageChunk(content=content, role=role) - else: - return default_class(content=content) # type: ignore[call-arg] + return default_class(content=content) # type: ignore[call-arg] def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict: @@ -321,18 +323,24 @@ def _convert_message_to_mistral_chat_message( message: BaseMessage, ) -> dict: if isinstance(message, ChatMessage): - return dict(role=message.role, content=message.content) - elif isinstance(message, HumanMessage): - return dict(role="user", content=message.content) - elif isinstance(message, AIMessage): + return {"role": message.role, "content": message.content} + if isinstance(message, HumanMessage): + return {"role": "user", "content": message.content} + if isinstance(message, AIMessage): message_dict: dict[str, Any] = {"role": "assistant"} tool_calls = [] if message.tool_calls or message.invalid_tool_calls: for tool_call in message.tool_calls: - tool_calls.append(_format_tool_call_for_mistral(tool_call)) + tool_calls.extend( + [ + _format_tool_call_for_mistral(tool_call) + for tool_call in message.tool_calls + ] + ) for invalid_tool_call in message.invalid_tool_calls: - tool_calls.append( + tool_calls.extend( _format_invalid_tool_call_for_mistral(invalid_tool_call) + for invalid_tool_call in message.invalid_tool_calls ) elif "tool_calls" in message.additional_kwargs: for tc in message.additional_kwargs["tool_calls"]: @@ -359,9 +367,9 @@ def _convert_message_to_mistral_chat_message( if "prefix" in message.additional_kwargs: message_dict["prefix"] = message.additional_kwargs["prefix"] return message_dict - elif isinstance(message, SystemMessage): - return dict(role="system", content=message.content) - elif isinstance(message, ToolMessage): + if isinstance(message, SystemMessage): + return {"role": "system", "content": message.content} + if isinstance(message, ToolMessage): return { "role": "tool", "content": message.content, @@ -370,8 +378,8 @@ def _convert_message_to_mistral_chat_message( message.tool_call_id ), } - else: - raise ValueError(f"Got unknown type {message}") + msg = f"Got unknown type {message}" + raise ValueError(msg) class ChatMistralAI(BaseChatModel): @@ -380,10 +388,10 @@ class ChatMistralAI(BaseChatModel): # The type for client and async_client is ignored because the type is not # an Optional after the model is initialized and the model_validator # is run. - client: httpx.Client = Field( # type: ignore # : meta private: + client: httpx.Client = Field( # type: ignore[assignment] # : meta private: default=None, exclude=True ) - async_client: httpx.AsyncClient = Field( # type: ignore # : meta private: + async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # : meta private: default=None, exclude=True ) #: :meta private: mistral_api_key: Optional[SecretStr] = Field( @@ -417,8 +425,7 @@ class ChatMistralAI(BaseChatModel): def build_extra(cls, values: dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) - values = _build_model_kwargs(values, all_required_field_names) - return values + return _build_model_kwargs(values, all_required_field_names) @property def _default_params(self) -> dict[str, Any]: @@ -432,8 +439,7 @@ class ChatMistralAI(BaseChatModel): "safe_prompt": self.safe_mode, **self.model_kwargs, } - filtered = {k: v for k, v in defaults.items() if v is not None} - return filtered + return {k: v for k, v in defaults.items() if v is not None} def _get_ls_params( self, stop: Optional[list[str]] = None, **kwargs: Any @@ -481,13 +487,11 @@ class ChatMistralAI(BaseChatModel): yield event.json() return iter_sse() - else: - response = self.client.post(url="/chat/completions", json=kwargs) - _raise_on_error(response) - return response.json() + response = self.client.post(url="/chat/completions", json=kwargs) + _raise_on_error(response) + return response.json() - rtn = _completion_with_retry(**kwargs) - return rtn + return _completion_with_retry(**kwargs) def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: overall_token_usage: dict = {} @@ -502,8 +506,7 @@ class ChatMistralAI(BaseChatModel): overall_token_usage[k] += v else: overall_token_usage[k] = v - combined = {"token_usage": overall_token_usage, "model_name": self.model} - return combined + return {"token_usage": overall_token_usage, "model_name": self.model} @model_validator(mode="after") def validate_environment(self) -> Self: @@ -545,10 +548,12 @@ class ChatMistralAI(BaseChatModel): ) if self.temperature is not None and not 0 <= self.temperature <= 1: - raise ValueError("temperature must be in the range [0.0, 1.0]") + msg = "temperature must be in the range [0.0, 1.0]" + raise ValueError(msg) if self.top_p is not None and not 0 <= self.top_p <= 1: - raise ValueError("top_p must be in the range [0.0, 1.0]") + msg = "top_p must be in the range [0.0, 1.0]" + raise ValueError(msg) return self @@ -557,7 +562,7 @@ class ChatMistralAI(BaseChatModel): messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, + stream: Optional[bool] = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -669,7 +674,7 @@ class ChatMistralAI(BaseChatModel): messages: list[BaseMessage], stop: Optional[list[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, + stream: Optional[bool] = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming @@ -689,7 +694,7 @@ class ChatMistralAI(BaseChatModel): def bind_tools( self, tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]], - tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, + tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, # noqa: PYI051 **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -707,15 +712,15 @@ class ChatMistralAI(BaseChatModel): {"type": "function", "function": {"name": <>}}. kwargs: Any additional parameters are passed directly to ``self.bind(**kwargs)``. - """ + """ formatted_tools = [convert_to_openai_tool(tool) for tool in tools] if tool_choice: tool_names = [] for tool in formatted_tools: - if "function" in tool and (name := tool["function"].get("name")): - tool_names.append(name) - elif name := tool.get("name"): + if ("function" in tool and (name := tool["function"].get("name"))) or ( + name := tool.get("name") + ): tool_names.append(name) else: pass @@ -738,7 +743,7 @@ class ChatMistralAI(BaseChatModel): include_raw: bool = False, **kwargs: Any, ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: - """Model wrapper that returns outputs formatted to match the given schema. + r"""Model wrapper that returns outputs formatted to match the given schema. Args: schema: @@ -785,6 +790,12 @@ class ChatMistralAI(BaseChatModel): will be caught and returned as well. The final output is always a dict with keys "raw", "parsed", and "parsing_error". + kwargs: Any additional parameters are passed directly to + ``self.bind(**kwargs)``. This is useful for passing in + parameters such as ``tool_choice`` or ``tools`` to control + which tool the model should call, or to pass in parameters such as + ``stop`` to control when the model should stop generating output. + Returns: A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. @@ -968,14 +979,16 @@ class ChatMistralAI(BaseChatModel): """ # noqa: E501 _ = kwargs.pop("strict", None) if kwargs: - raise ValueError(f"Received unsupported arguments {kwargs}") + msg = f"Received unsupported arguments {kwargs}" + raise ValueError(msg) is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) if method == "function_calling": if schema is None: - raise ValueError( + msg = ( "schema must be specified when method is 'function_calling'. " "Received None." ) + raise ValueError(msg) # TODO: Update to pass in tool name as tool_choice if/when Mistral supports # specifying a tool. llm = self.bind_tools( @@ -1014,10 +1027,11 @@ class ChatMistralAI(BaseChatModel): ) elif method == "json_schema": if schema is None: - raise ValueError( + msg = ( "schema must be specified when method is 'json_schema'. " "Received None." ) + raise ValueError(msg) response_format = _convert_to_openai_response_format(schema, strict=True) llm = self.bind( response_format=response_format, @@ -1041,8 +1055,7 @@ class ChatMistralAI(BaseChatModel): [parser_none], exception_key="parsing_error" ) return RunnableMap(raw=llm) | parser_with_fallback - else: - return llm | output_parser + return llm | output_parser @property def _identifying_params(self) -> dict[str, Any]: @@ -1072,7 +1085,7 @@ class ChatMistralAI(BaseChatModel): def _convert_to_openai_response_format( schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None ) -> dict: - """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels.""" + """Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels.""" if ( isinstance(schema, dict) and "json_schema" in schema diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 575ceb0c637..6eff302fa5b 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -17,20 +17,20 @@ from pydantic import ( model_validator, ) from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed -from tokenizers import Tokenizer # type: ignore +from tokenizers import Tokenizer # type: ignore[import] from typing_extensions import Self logger = logging.getLogger(__name__) MAX_TOKENS = 16_000 """A batching parameter for the Mistral API. This is NOT the maximum number of tokens -accepted by the embedding model for each document/chunk, but rather the maximum number +accepted by the embedding model for each document/chunk, but rather the maximum number of tokens that can be sent in a single request to the Mistral API (across multiple documents/chunks)""" class DummyTokenizer: - """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)""" + """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface).""" @staticmethod def encode_batch(texts: list[str]) -> list[list[str]]: @@ -126,9 +126,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings): # The type for client and async_client is ignored because the type is not # an Optional after the model is initialized and the model_validator # is run. - client: httpx.Client = Field(default=None) # type: ignore # : :meta private: + client: httpx.Client = Field(default=None) # type: ignore[assignment] # :meta private: - async_client: httpx.AsyncClient = Field( # type: ignore # : meta private: + async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # :meta private: default=None ) mistral_api_key: SecretStr = Field( @@ -153,7 +153,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings): @model_validator(mode="after") def validate_environment(self) -> Self: """Validate configuration.""" - api_key_str = self.mistral_api_key.get_secret_value() # todo: handle retries if not self.client: @@ -187,14 +186,14 @@ class MistralAIEmbeddings(BaseModel, Embeddings): "Could not download mistral tokenizer from Huggingface for " "calculating batch sizes. Set a Huggingface token via the " "HF_TOKEN environment variable to download the real tokenizer. " - "Falling back to a dummy tokenizer that uses `len()`." + "Falling back to a dummy tokenizer that uses `len()`.", + stacklevel=2, ) self.tokenizer = DummyTokenizer() return self def _get_batches(self, texts: list[str]) -> Iterable[list[str]]: - """Split a list of texts into batches of less than 16k tokens for Mistral - API.""" + """Split list of texts into batches of less than 16k tokens for Mistral API.""" batch: list[str] = [] batch_tokens = 0 @@ -224,6 +223,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. + """ try: batch_responses = [] @@ -238,16 +238,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings): def _embed_batch(batch: list[str]) -> Response: response = self.client.post( url="/embeddings", - json=dict( - model=self.model, - input=batch, - ), + json={ + "model": self.model, + "input": batch, + }, ) response.raise_for_status() return response - for batch in self._get_batches(texts): - batch_responses.append(_embed_batch(batch)) + batch_responses = [ + _embed_batch(batch) for batch in self._get_batches(texts) + ] return [ list(map(float, embedding_obj["embedding"])) for response in batch_responses @@ -265,16 +266,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. + """ try: batch_responses = await asyncio.gather( *[ self.async_client.post( url="/embeddings", - json=dict( - model=self.model, - input=batch, - ), + json={ + "model": self.model, + "input": batch, + }, ) for batch in self._get_batches(texts) ] @@ -296,6 +298,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): Returns: Embedding for the text. + """ return self.embed_documents([text])[0] @@ -307,5 +310,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings): Returns: Embedding for the text. + """ return (await self.aembed_documents([text]))[0] diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index 1b521b7742b..131a599fc4b 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -48,8 +48,62 @@ disallow_untyped_defs = "True" target-version = "py39" [tool.ruff.lint] -select = ["E", "F", "I", "T201", "UP", "S"] -ignore = [ "UP007", ] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "COM", # flake8-commas + "D", # pydocstyle + "DOC", # pydoclint + "E", # pycodestyle error + "EM", # flake8-errmsg + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT", # flake8-boolean-trap + "FLY", # flake8-flynt + "I", # isort + "ICN", # flake8-import-conventions + "INT", # flake8-gettext + "ISC", # isort-comprehensions + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PERF", # flake8-perf + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-rst-docstrings + "RUF", # ruff + "S", # flake8-bandit + "SLF", # flake8-self + "SLOT", # flake8-slots + "SIM", # flake8-simplify + "T10", # flake8-debugger + "T20", # flake8-print + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 +] +ignore = [ + "D100", # pydocstyle: Missing docstring in public module + "D101", # pydocstyle: Missing docstring in public class + "D102", # pydocstyle: Missing docstring in public method + "D103", # pydocstyle: Missing docstring in public function + "D104", # pydocstyle: Missing docstring in public package + "D105", # pydocstyle: Missing docstring in magic method + "D107", # pydocstyle: Missing docstring in __init__ + "D203", # Messes with the formatter + "D407", # pydocstyle: Missing-dashed-underline-after-section + "COM812", # Messes with the formatter + "ISC001", # Messes with the formatter + "PERF203", # Rarely useful + "S112", # Rarely useful + "RUF012", # Doesn't play well with Pydantic + "SLF001", # Private member access + "UP007", # pyupgrade: non-pep604-annotation-union + "UP045", # pyupgrade: non-pep604-annotation-optional +] [tool.coverage.run] omit = ["tests/*"] diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index 16c91d419fc..df446b02b47 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -1,5 +1,7 @@ """Test ChatMistral chat model.""" +from __future__ import annotations + import json import logging import time @@ -43,11 +45,12 @@ async def test_astream() -> None: if token.response_metadata: chunks_with_response_metadata += 1 if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: - raise AssertionError( + msg = ( "Expected exactly one chunk with token counts or response_metadata. " "AIMessageChunk aggregation adds / appends counts and metadata. Check that " "this is behaving properly." ) + raise AssertionError(msg) assert isinstance(full, AIMessageChunk) assert full.usage_metadata is not None assert full.usage_metadata["input_tokens"] > 0 @@ -61,7 +64,7 @@ async def test_astream() -> None: async def test_abatch() -> None: - """Test streaming tokens from ChatMistralAI""" + """Test streaming tokens from ChatMistralAI.""" llm = ChatMistralAI() result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) @@ -70,7 +73,7 @@ async def test_abatch() -> None: async def test_abatch_tags() -> None: - """Test batch tokens from ChatMistralAI""" + """Test batch tokens from ChatMistralAI.""" llm = ChatMistralAI() result = await llm.abatch( @@ -81,7 +84,7 @@ async def test_abatch_tags() -> None: def test_batch() -> None: - """Test batch tokens from ChatMistralAI""" + """Test batch tokens from ChatMistralAI.""" llm = ChatMistralAI() result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) @@ -90,7 +93,7 @@ def test_batch() -> None: async def test_ainvoke() -> None: - """Test invoke tokens from ChatMistralAI""" + """Test invoke tokens from ChatMistralAI.""" llm = ChatMistralAI() result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) @@ -99,10 +102,10 @@ async def test_ainvoke() -> None: def test_invoke() -> None: - """Test invoke tokens from ChatMistralAI""" + """Test invoke tokens from ChatMistralAI.""" llm = ChatMistralAI() - result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) + result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]}) assert isinstance(result.content, str) @@ -178,13 +181,11 @@ def test_streaming_structured_output() -> None: structured_llm = llm.with_structured_output(Person) strm = structured_llm.stream("Erick, 27 years old") - chunk_num = 0 - for chunk in strm: + for chunk_num, chunk in enumerate(strm): assert chunk_num == 0, "should only have one chunk with model" assert isinstance(chunk, Person) assert chunk.name == "Erick" assert chunk.age == 27 - chunk_num += 1 class Book(BaseModel): @@ -201,7 +202,7 @@ def _check_parsed_result(result: Any, schema: Any) -> None: if schema == Book: assert isinstance(result, Book) else: - assert all(key in ["name", "authors"] for key in result.keys()) + assert all(key in ["name", "authors"] for key in result) @pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()]) diff --git a/libs/partners/mistralai/tests/integration_tests/test_compile.py b/libs/partners/mistralai/tests/integration_tests/test_compile.py index 33ecccdfa0f..f315e45f521 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_compile.py +++ b/libs/partners/mistralai/tests/integration_tests/test_compile.py @@ -4,4 +4,3 @@ import pytest @pytest.mark.compile def test_placeholder() -> None: """Used for compiling integration tests without running any real tests.""" - pass diff --git a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py index 345c29b3555..299feb6f935 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py @@ -1,4 +1,4 @@ -"""Test MistralAI Embedding""" +"""Test MistralAI Embedding.""" from langchain_mistralai import MistralAIEmbeddings diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index af51d28f98d..8c716c65281 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" from langchain_core.language_models import BaseChatModel from langchain_tests.integration_tests import ( # type: ignore[import-not-found] diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 22dd3fa3e72..a633b7d2696 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -84,23 +84,23 @@ def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None: [ ( SystemMessage(content="Hello"), - dict(role="system", content="Hello"), + {"role": "system", "content": "Hello"}, ), ( HumanMessage(content="Hello"), - dict(role="user", content="Hello"), + {"role": "user", "content": "Hello"}, ), ( AIMessage(content="Hello"), - dict(role="assistant", content="Hello"), + {"role": "assistant", "content": "Hello"}, ), ( AIMessage(content="{", additional_kwargs={"prefix": True}), - dict(role="assistant", content="{", prefix=True), + {"role": "assistant", "content": "{", "prefix": True}, ), ( ChatMessage(role="assistant", content="Hello"), - dict(role="assistant", content="Hello"), + {"role": "assistant", "content": "Hello"}, ), ], ) @@ -112,17 +112,17 @@ def test_convert_message_to_mistral_chat_message( def _make_completion_response_from_token(token: str) -> dict: - return dict( - id="abc123", - model="fake_model", - choices=[ - dict( - index=0, - delta=dict(content=token), - finish_reason=None, - ) + return { + "id": "abc123", + "model": "fake_model", + "choices": [ + { + "index": 0, + "delta": {"content": token}, + "finish_reason": None, + } ], - ) + } def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator: @@ -275,8 +275,7 @@ def test_extra_kwargs() -> None: def test_retry_with_failure_then_success() -> None: - """Test that retry mechanism works correctly when - first request fails and second succeeds.""" + """Test retry mechanism works correctly when fiest request fails, second succeed.""" # Create a real ChatMistralAI instance chat = ChatMistralAI(max_retries=3) @@ -289,7 +288,8 @@ def test_retry_with_failure_then_success() -> None: call_count += 1 if call_count == 1: - raise httpx.RequestError("Connection error", request=MagicMock()) + msg = "Connection error" + raise httpx.RequestError(msg, request=MagicMock()) mock_response = MagicMock() mock_response.status_code = 200 diff --git a/libs/partners/mistralai/tests/unit_tests/test_standard.py b/libs/partners/mistralai/tests/unit_tests/test_standard.py index 4ba75d610ba..8ff110eaeae 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_standard.py +++ b/libs/partners/mistralai/tests/unit_tests/test_standard.py @@ -1,4 +1,4 @@ -"""Standard LangChain interface tests""" +"""Standard LangChain interface tests.""" from langchain_core.language_models import BaseChatModel from langchain_tests.unit_tests import ( # type: ignore[import-not-found] diff --git a/libs/partners/mistralai/uv.lock b/libs/partners/mistralai/uv.lock index b470a27c980..5caf774d333 100644 --- a/libs/partners/mistralai/uv.lock +++ b/libs/partners/mistralai/uv.lock @@ -379,7 +379,7 @@ dev = [ { name = "jupyter", specifier = ">=1.0.0,<2.0.0" }, { name = "setuptools", specifier = ">=67.6.1,<68.0.0" }, ] -lint = [{ name = "ruff", specifier = ">=0.11.2,<0.12.0" }] +lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }] test = [ { name = "blockbuster", specifier = "~=1.5.18" }, { name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, @@ -452,7 +452,7 @@ requires-dist = [ [package.metadata.requires-dev] codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }] dev = [{ name = "langchain-core", editable = "../../core" }] -lint = [{ name = "ruff", specifier = ">=0.5,<1.0" }] +lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }] test = [ { name = "langchain-core", editable = "../../core" }, { name = "langchain-tests", editable = "../../standard-tests" }, @@ -1370,27 +1370,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.9.4" +version = "0.12.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c0/17/529e78f49fc6f8076f50d985edd9a2cf011d1dbadb1cdeacc1d12afc1d26/ruff-0.9.4.tar.gz", hash = "sha256:6907ee3529244bb0ed066683e075f09285b38dd5b4039370df6ff06041ca19e7", size = 3599458, upload-time = "2025-01-30T18:09:51.03Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/3d/d9a195676f25d00dbfcf3cf95fdd4c685c497fcfa7e862a44ac5e4e96480/ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e", size = 4432239, upload-time = "2025-07-03T16:40:19.566Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b6/f8/3fafb7804d82e0699a122101b5bee5f0d6e17c3a806dcbc527bb7d3f5b7a/ruff-0.9.4-py3-none-linux_armv6l.whl", hash = "sha256:64e73d25b954f71ff100bb70f39f1ee09e880728efb4250c632ceed4e4cdf706", size = 11668400, upload-time = "2025-01-30T18:08:46.508Z" }, - { url = "https://files.pythonhosted.org/packages/2e/a6/2efa772d335da48a70ab2c6bb41a096c8517ca43c086ea672d51079e3d1f/ruff-0.9.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6ce6743ed64d9afab4fafeaea70d3631b4d4b28b592db21a5c2d1f0ef52934bf", size = 11628395, upload-time = "2025-01-30T18:08:50.87Z" }, - { url = "https://files.pythonhosted.org/packages/dc/d7/cd822437561082f1c9d7225cc0d0fbb4bad117ad7ac3c41cd5d7f0fa948c/ruff-0.9.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:54499fb08408e32b57360f6f9de7157a5fec24ad79cb3f42ef2c3f3f728dfe2b", size = 11090052, upload-time = "2025-01-30T18:08:54.498Z" }, - { url = "https://files.pythonhosted.org/packages/9e/67/3660d58e893d470abb9a13f679223368ff1684a4ef40f254a0157f51b448/ruff-0.9.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37c892540108314a6f01f105040b5106aeb829fa5fb0561d2dcaf71485021137", size = 11882221, upload-time = "2025-01-30T18:08:57.784Z" }, - { url = "https://files.pythonhosted.org/packages/79/d1/757559995c8ba5f14dfec4459ef2dd3fcea82ac43bc4e7c7bf47484180c0/ruff-0.9.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:de9edf2ce4b9ddf43fd93e20ef635a900e25f622f87ed6e3047a664d0e8f810e", size = 11424862, upload-time = "2025-01-30T18:09:01.167Z" }, - { url = "https://files.pythonhosted.org/packages/c0/96/7915a7c6877bb734caa6a2af424045baf6419f685632469643dbd8eb2958/ruff-0.9.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87c90c32357c74f11deb7fbb065126d91771b207bf9bfaaee01277ca59b574ec", size = 12626735, upload-time = "2025-01-30T18:09:05.312Z" }, - { url = "https://files.pythonhosted.org/packages/0e/cc/dadb9b35473d7cb17c7ffe4737b4377aeec519a446ee8514123ff4a26091/ruff-0.9.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:56acd6c694da3695a7461cc55775f3a409c3815ac467279dfa126061d84b314b", size = 13255976, upload-time = "2025-01-30T18:09:09.425Z" }, - { url = "https://files.pythonhosted.org/packages/5f/c3/ad2dd59d3cabbc12df308cced780f9c14367f0321e7800ca0fe52849da4c/ruff-0.9.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0c93e7d47ed951b9394cf352d6695b31498e68fd5782d6cbc282425655f687a", size = 12752262, upload-time = "2025-01-30T18:09:13.112Z" }, - { url = "https://files.pythonhosted.org/packages/c7/17/5f1971e54bd71604da6788efd84d66d789362b1105e17e5ccc53bba0289b/ruff-0.9.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d4c8772670aecf037d1bf7a07c39106574d143b26cfe5ed1787d2f31e800214", size = 14401648, upload-time = "2025-01-30T18:09:17.086Z" }, - { url = "https://files.pythonhosted.org/packages/30/24/6200b13ea611b83260501b6955b764bb320e23b2b75884c60ee7d3f0b68e/ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc5f1d7afeda8d5d37660eeca6d389b142d7f2b5a1ab659d9214ebd0e025231", size = 12414702, upload-time = "2025-01-30T18:09:21.672Z" }, - { url = "https://files.pythonhosted.org/packages/34/cb/f5d50d0c4ecdcc7670e348bd0b11878154bc4617f3fdd1e8ad5297c0d0ba/ruff-0.9.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faa935fc00ae854d8b638c16a5f1ce881bc3f67446957dd6f2af440a5fc8526b", size = 11859608, upload-time = "2025-01-30T18:09:25.663Z" }, - { url = "https://files.pythonhosted.org/packages/d6/f4/9c8499ae8426da48363bbb78d081b817b0f64a9305f9b7f87eab2a8fb2c1/ruff-0.9.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6c634fc6f5a0ceae1ab3e13c58183978185d131a29c425e4eaa9f40afe1e6d6", size = 11485702, upload-time = "2025-01-30T18:09:28.903Z" }, - { url = "https://files.pythonhosted.org/packages/18/59/30490e483e804ccaa8147dd78c52e44ff96e1c30b5a95d69a63163cdb15b/ruff-0.9.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:433dedf6ddfdec7f1ac7575ec1eb9844fa60c4c8c2f8887a070672b8d353d34c", size = 12067782, upload-time = "2025-01-30T18:09:32.371Z" }, - { url = "https://files.pythonhosted.org/packages/3d/8c/893fa9551760b2f8eb2a351b603e96f15af167ceaf27e27ad873570bc04c/ruff-0.9.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d612dbd0f3a919a8cc1d12037168bfa536862066808960e0cc901404b77968f0", size = 12483087, upload-time = "2025-01-30T18:09:36.124Z" }, - { url = "https://files.pythonhosted.org/packages/23/15/f6751c07c21ca10e3f4a51ea495ca975ad936d780c347d9808bcedbd7182/ruff-0.9.4-py3-none-win32.whl", hash = "sha256:db1192ddda2200671f9ef61d9597fcef89d934f5d1705e571a93a67fb13a4402", size = 9852302, upload-time = "2025-01-30T18:09:40.013Z" }, - { url = "https://files.pythonhosted.org/packages/12/41/2d2d2c6a72e62566f730e49254f602dfed23019c33b5b21ea8f8917315a1/ruff-0.9.4-py3-none-win_amd64.whl", hash = "sha256:05bebf4cdbe3ef75430d26c375773978950bbf4ee3c95ccb5448940dc092408e", size = 10850051, upload-time = "2025-01-30T18:09:43.42Z" }, - { url = "https://files.pythonhosted.org/packages/c6/e6/3d6ec3bc3d254e7f005c543a661a41c3e788976d0e52a1ada195bd664344/ruff-0.9.4-py3-none-win_arm64.whl", hash = "sha256:585792f1e81509e38ac5123492f8875fbc36f3ede8185af0a26df348e5154f41", size = 10078251, upload-time = "2025-01-30T18:09:48.01Z" }, + { url = "https://files.pythonhosted.org/packages/74/b6/2098d0126d2d3318fd5bec3ad40d06c25d377d95749f7a0c5af17129b3b1/ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be", size = 10369761, upload-time = "2025-07-03T16:39:38.847Z" }, + { url = "https://files.pythonhosted.org/packages/b1/4b/5da0142033dbe155dc598cfb99262d8ee2449d76920ea92c4eeb9547c208/ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e", size = 11155659, upload-time = "2025-07-03T16:39:42.294Z" }, + { url = "https://files.pythonhosted.org/packages/3e/21/967b82550a503d7c5c5c127d11c935344b35e8c521f52915fc858fb3e473/ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc", size = 10537769, upload-time = "2025-07-03T16:39:44.75Z" }, + { url = "https://files.pythonhosted.org/packages/33/91/00cff7102e2ec71a4890fb7ba1803f2cdb122d82787c7d7cf8041fe8cbc1/ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922", size = 10717602, upload-time = "2025-07-03T16:39:47.652Z" }, + { url = "https://files.pythonhosted.org/packages/9b/eb/928814daec4e1ba9115858adcda44a637fb9010618721937491e4e2283b8/ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b", size = 10198772, upload-time = "2025-07-03T16:39:49.641Z" }, + { url = "https://files.pythonhosted.org/packages/50/fa/f15089bc20c40f4f72334f9145dde55ab2b680e51afb3b55422effbf2fb6/ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d", size = 11845173, upload-time = "2025-07-03T16:39:52.069Z" }, + { url = "https://files.pythonhosted.org/packages/43/9f/1f6f98f39f2b9302acc161a4a2187b1e3a97634fe918a8e731e591841cf4/ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1", size = 12553002, upload-time = "2025-07-03T16:39:54.551Z" }, + { url = "https://files.pythonhosted.org/packages/d8/70/08991ac46e38ddd231c8f4fd05ef189b1b94be8883e8c0c146a025c20a19/ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4", size = 12171330, upload-time = "2025-07-03T16:39:57.55Z" }, + { url = "https://files.pythonhosted.org/packages/88/a9/5a55266fec474acfd0a1c73285f19dd22461d95a538f29bba02edd07a5d9/ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9", size = 11774717, upload-time = "2025-07-03T16:39:59.78Z" }, + { url = "https://files.pythonhosted.org/packages/87/e5/0c270e458fc73c46c0d0f7cf970bb14786e5fdb88c87b5e423a4bd65232b/ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da", size = 11646659, upload-time = "2025-07-03T16:40:01.934Z" }, + { url = "https://files.pythonhosted.org/packages/b7/b6/45ab96070c9752af37f0be364d849ed70e9ccede07675b0ec4e3ef76b63b/ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce", size = 10604012, upload-time = "2025-07-03T16:40:04.363Z" }, + { url = "https://files.pythonhosted.org/packages/86/91/26a6e6a424eb147cc7627eebae095cfa0b4b337a7c1c413c447c9ebb72fd/ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d", size = 10176799, upload-time = "2025-07-03T16:40:06.514Z" }, + { url = "https://files.pythonhosted.org/packages/f5/0c/9f344583465a61c8918a7cda604226e77b2c548daf8ef7c2bfccf2b37200/ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04", size = 11241507, upload-time = "2025-07-03T16:40:08.708Z" }, + { url = "https://files.pythonhosted.org/packages/1c/b7/99c34ded8fb5f86c0280278fa89a0066c3760edc326e935ce0b1550d315d/ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342", size = 11717609, upload-time = "2025-07-03T16:40:10.836Z" }, + { url = "https://files.pythonhosted.org/packages/51/de/8589fa724590faa057e5a6d171e7f2f6cffe3287406ef40e49c682c07d89/ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a", size = 10523823, upload-time = "2025-07-03T16:40:13.203Z" }, + { url = "https://files.pythonhosted.org/packages/94/47/8abf129102ae4c90cba0c2199a1a9b0fa896f6f806238d6f8c14448cc748/ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639", size = 11629831, upload-time = "2025-07-03T16:40:15.478Z" }, + { url = "https://files.pythonhosted.org/packages/e2/1f/72d2946e3cc7456bb837e88000eb3437e55f80db339c840c04015a11115d/ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12", size = 10735334, upload-time = "2025-07-03T16:40:17.677Z" }, ] [[package]]