Fix: Enable max_retries Parameter in ChatMistralAI Class (#30448)

**partners: Enable max_retries in ChatMistralAI**

**Description**

- This pull request reactivates the retry logic in the
completion_with_retry method of the ChatMistralAI class, restoring the
intended functionality of the previously ineffective max_retries
parameter. It also adds a unit test that mocks a failed call followed by
a successful retry, and an integration test that confirms end-to-end
functionality.

**Issue**
- Closes #30362

**Dependencies**
- No additional dependencies required

Co-authored-by: andrasfe <andrasf94@gmail.com>
This commit is contained in:
Andras L Ferenczi 2025-03-27 11:53:44 -04:00 committed by GitHub
parent 3aa080c2a8
commit 63673b765b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 86 additions and 3 deletions

View File

@ -464,9 +464,9 @@ class ChatMistralAI(BaseChatModel):
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any: ) -> Any:
"""Use tenacity to retry the completion call.""" """Use tenacity to retry the completion call."""
# retry_decorator = _create_retry_decorator(self, run_manager=run_manager) retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
# @retry_decorator @retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any: def _completion_with_retry(**kwargs: Any) -> Any:
if "stream" not in kwargs: if "stream" not in kwargs:
kwargs["stream"] = False kwargs["stream"] = False

View File

@ -1,9 +1,12 @@
"""Test ChatMistral chat model.""" """Test ChatMistral chat model."""
import json import json
import logging
import time
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from httpx import ReadTimeout
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk, AIMessageChunk,
@ -301,3 +304,39 @@ def test_streaming_tool_call() -> None:
acc = chunk if acc is None else acc + chunk acc = chunk if acc is None else acc + chunk
assert acc.content != "" assert acc.content != ""
assert "tool_calls" not in acc.additional_kwargs assert "tool_calls" not in acc.additional_kwargs
def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
    """Test that retry parameters are honored in ChatMistralAI.

    Builds a model with a 1-second timeout and ``max_retries=3`` and invokes
    it once. If the call succeeds, the response content is validated. If the
    call ends in ``httpx.ReadTimeout``, the test checks that at least three
    seconds elapsed and then marks itself as skipped. Any other exception
    propagates unchanged and fails the test.

    Args:
        caplog: pytest log-capture fixture; requested by the signature but not
            inspected in the body.
    """
    # Create a model with intentionally short timeout and multiple retries
    mistral = ChatMistralAI(
        timeout=1,  # Very short timeout to trigger timeouts
        max_retries=3,  # Should retry 3 times
    )

    # Simple test input that should take longer than 1 second to process
    test_input = "Write a 2 sentence story about a cat"

    # Use a monotonic clock: the assertion below is about a duration, and a
    # wall-clock adjustment during the test must not be able to skew it.
    t0 = time.monotonic()

    # Keep the try body down to the single call that can time out, so that a
    # failing assertion on the response is never mistaken for a request error.
    try:
        response = mistral.invoke(test_input)
    except ReadTimeout:
        elapsed_time = time.monotonic() - t0
        logging.info("Request timed out after %.2f seconds", elapsed_time)
        # NOTE(review): the 3.0 s floor presumably corresponds to
        # max_retries (3) x timeout (1 s) -- confirm against the retry policy.
        assert elapsed_time >= 3.0
        pytest.skip("Test timed out as expected with short timeout")
    else:
        elapsed_time = time.monotonic() - t0
        logging.info("Request succeeded in %.2f seconds", elapsed_time)

        # Check that we got a valid response
        assert response.content
        assert isinstance(response.content, str)
        assert "cat" in response.content.lower()

View File

@ -2,8 +2,9 @@
import os import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
from langchain_core.callbacks.base import BaseCallbackHandler from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import ( from langchain_core.messages import (
@ -270,3 +271,46 @@ def test_extra_kwargs() -> None:
# Test that if provided twice it errors # Test that if provided twice it errors
with pytest.raises(ValueError): with pytest.raises(ValueError):
ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg] ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
def test_retry_with_failure_then_success() -> None:
    """Verify that a failed first request is retried and the second succeeds.

    The ``post`` attribute of the model's client is patched with a fake that
    raises ``httpx.RequestError`` on its first call and returns a canned
    successful completion on every later call. The test then checks both the
    returned content and the number of times the fake was called.
    """
    # A real ChatMistralAI instance, so the actual retry machinery is exercised
    # rather than a mocked stand-in for it.
    chat = ChatMistralAI(max_retries=3)

    def _successful_response() -> MagicMock:
        """Build a fresh mock HTTP 200 response carrying one assistant message."""
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "Hello!",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            },
        }
        return reply

    # Number of times the patched ``post`` has been invoked so far.
    call_count = 0

    def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
        nonlocal call_count
        call_count += 1
        # Every call after the first succeeds; only the first one fails.
        if call_count != 1:
            return _successful_response()
        raise httpx.RequestError("Connection error", request=MagicMock())

    with patch.object(chat.client, "post", side_effect=mock_post):
        result = chat.invoke("Hello")

    assert result.content == "Hello!"
    assert call_count == 2, f"Expected 2 calls, but got {call_count}"