Fix: Enable max_retries Parameter in ChatMistralAI Class (#30448)

**partners: Enable max_retries in ChatMistralAI**

**Description**

- This pull request reactivates the retry logic in the
completion_with_retry method of the ChatMistralAI class, restoring the
intended functionality of the previously ineffective max_retries
parameter. New unit test that mocks failed/successful retry calls and an
integration test to confirm end-to-end functionality.

**Issue**
- Closes #30362

**Dependencies**
- No additional dependencies required

Co-authored-by: andrasfe <andrasf94@gmail.com>
This commit is contained in:
Andras L Ferenczi
2025-03-27 11:53:44 -04:00
committed by GitHub
parent 3aa080c2a8
commit 63673b765b
3 changed files with 86 additions and 3 deletions

View File

@@ -2,8 +2,9 @@
import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import httpx
import pytest
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import (
@@ -270,3 +271,46 @@ def test_extra_kwargs() -> None:
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
def test_retry_with_failure_then_success() -> None:
"""Test that retry mechanism works correctly when
first request fails and second succeeds."""
# Create a real ChatMistralAI instance
chat = ChatMistralAI(max_retries=3)
# Set up the actual retry mechanism (not just mocking it)
# We'll track how many times the function is called
call_count = 0
def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
nonlocal call_count
call_count += 1
if call_count == 1:
raise httpx.RequestError("Connection error", request=MagicMock())
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [
{
"message": {
"role": "assistant",
"content": "Hello!",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
},
}
return mock_response
with patch.object(chat.client, "post", side_effect=mock_post):
result = chat.invoke("Hello")
assert result.content == "Hello!"
assert call_count == 2, f"Expected 2 calls, but got {call_count}"