From 63673b765b7deffcd116ba8f3833a7ff8a2ac8d2 Mon Sep 17 00:00:00 2001
From: Andras L Ferenczi
Date: Thu, 27 Mar 2025 11:53:44 -0400
Subject: [PATCH] Fix: Enable max_retries Parameter in ChatMistralAI Class
 (#30448)

**partners: Enable max_retries in ChatMistralAI**

**Description**
- This pull request reactivates the retry logic in the completion_with_retry
  method of the ChatMistralAI class, restoring the intended functionality of
  the previously ineffective max_retries parameter. A new unit test mocks a
  failed call followed by a successful retry, and an integration test confirms
  end-to-end functionality.

**Issue**
- Closes #30362

**Dependencies**
- No additional dependencies required

Co-authored-by: andrasfe
---
 .../langchain_mistralai/chat_models.py        |  4 +-
 .../integration_tests/test_chat_models.py     | 39 ++++++++++++++++
 .../tests/unit_tests/test_chat_models.py      | 46 ++++++++++++++++++-
 3 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 6f3cac19904..bc761d2bb6c 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -464,9 +464,9 @@ class ChatMistralAI(BaseChatModel):
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
         """Use tenacity to retry the completion call."""
-        # retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
 
-        # @retry_decorator
+        @retry_decorator
         def _completion_with_retry(**kwargs: Any) -> Any:
             if "stream" not in kwargs:
                 kwargs["stream"] = False
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index 8bec346d29a..16c91d419fc 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -1,9 +1,12 @@
 """Test ChatMistral chat model."""
 
 import json
+import logging
+import time
 from typing import Any, Optional
 
 import pytest
+from httpx import ReadTimeout
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -301,3 +304,39 @@ def test_streaming_tool_call() -> None:
         acc = chunk if acc is None else acc + chunk
     assert acc.content != ""
     assert "tool_calls" not in acc.additional_kwargs
+
+
+def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
+    """Test that retry parameters are honored in ChatMistralAI."""
+    # Create a model with intentionally short timeout and multiple retries
+    mistral = ChatMistralAI(
+        timeout=1,  # Very short timeout to trigger timeouts
+        max_retries=3,  # Should retry 3 times
+    )
+
+    # Simple test input that should take longer than 1 second to process
+    test_input = "Write a 2 sentence story about a cat"
+
+    # Measure start time
+    t0 = time.time()
+
+    try:
+        # Try to get a response
+        response = mistral.invoke(test_input)
+
+        # If successful, validate the response
+        elapsed_time = time.time() - t0
+        logging.info(f"Request succeeded in {elapsed_time:.2f} seconds")
+        # Check that we got a valid response
+        assert response.content
+        assert isinstance(response.content, str)
+        assert "cat" in response.content.lower()
+
+    except ReadTimeout:
+        elapsed_time = time.time() - t0
+        logging.info(f"Request timed out after {elapsed_time:.2f} seconds")
+        assert elapsed_time >= 3.0
+        pytest.skip("Test timed out as expected with short timeout")
+    except Exception as e:
+        logging.error(f"Unexpected exception: {e}")
+        raise
diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
index 4dc251832e7..6a94f431cfd 100644
--- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
@@ -2,8 +2,9 @@
 
 import os
 from typing import Any, AsyncGenerator, Dict, Generator, List, cast
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
+import httpx
 import pytest
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages import (
@@ -270,3 +271,46 @@
     # Test that if provided twice it errors
     with pytest.raises(ValueError):
         ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+
+
+def test_retry_with_failure_then_success() -> None:
+    """Test that retry mechanism works correctly when
+    first request fails and second succeeds."""
+    # Create a real ChatMistralAI instance
+    chat = ChatMistralAI(max_retries=3)
+
+    # Set up the actual retry mechanism (not just mocking it)
+    # We'll track how many times the function is called
+    call_count = 0
+
+    def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
+        nonlocal call_count
+        call_count += 1
+
+        if call_count == 1:
+            raise httpx.RequestError("Connection error", request=MagicMock())
+
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": "Hello!",
+                    },
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 1,
+                "completion_tokens": 1,
+                "total_tokens": 2,
+            },
+        }
+        return mock_response
+
+    with patch.object(chat.client, "post", side_effect=mock_post):
+        result = chat.invoke("Hello")
+        assert result.content == "Hello!"
+        assert call_count == 2, f"Expected 2 calls, but got {call_count}"
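
For context: the patch only re-enables the decorator returned by `_create_retry_decorator`, whose body is not shown in this diff. Below is a minimal, hypothetical sketch of what a tenacity-based retry decorator of this general shape can look like, given the docstring's statement that tenacity is used. The helper name `_sketch_retry_decorator`, the backoff settings, and the retried exception type are illustrative assumptions, not the library's actual implementation.

```python
# Illustrative sketch only -- not the actual langchain_mistralai helper.
from typing import Any, Callable

import httpx
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


def _sketch_retry_decorator(max_retries: int) -> Callable[..., Any]:
    """Build a decorator that retries transient httpx failures (assumed settings)."""
    return retry(
        reraise=True,  # re-raise the last exception instead of tenacity's RetryError
        stop=stop_after_attempt(max_retries),  # cap the number of attempts via max_retries
        wait=wait_exponential(multiplier=1, min=1, max=10),  # exponential backoff between tries
        retry=retry_if_exception_type(httpx.RequestError),  # timeouts and connection errors
    )
```

With a decorator of this kind applied to `_completion_with_retry`, the first `httpx.RequestError` raised by the mocked `post` call in the unit test is retried transparently, which is why that test asserts exactly two calls.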