Fix: Enable max_retries Parameter in ChatMistralAI Class (#30448)

**partners: Enable max_retries in ChatMistralAI**

**Description**

- This pull request reactivates the retry logic in the
completion_with_retry method of the ChatMistralAI class, restoring the
intended functionality of the previously ineffective max_retries
parameter. New unit test that mocks failed/successful retry calls and an
integration test to confirm end-to-end functionality.

**Issue**
- Closes #30362

**Dependencies**
- No additional dependencies required

Co-authored-by: andrasfe <andrasf94@gmail.com>
This commit is contained in:
Andras L Ferenczi
2025-03-27 11:53:44 -04:00
committed by GitHub
parent 3aa080c2a8
commit 63673b765b
3 changed files with 86 additions and 3 deletions

View File

@@ -1,9 +1,12 @@
"""Test ChatMistral chat model."""
import json
import logging
import time
from typing import Any, Optional
import pytest
from httpx import ReadTimeout
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -301,3 +304,39 @@ def test_streaming_tool_call() -> None:
acc = chunk if acc is None else acc + chunk
assert acc.content != ""
assert "tool_calls" not in acc.additional_kwargs
def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
"""Test that retry parameters are honored in ChatMistralAI."""
# Create a model with intentionally short timeout and multiple retries
mistral = ChatMistralAI(
timeout=1, # Very short timeout to trigger timeouts
max_retries=3, # Should retry 3 times
)
# Simple test input that should take longer than 1 second to process
test_input = "Write a 2 sentence story about a cat"
# Measure start time
t0 = time.time()
try:
# Try to get a response
response = mistral.invoke(test_input)
# If successful, validate the response
elapsed_time = time.time() - t0
logging.info(f"Request succeeded in {elapsed_time:.2f} seconds")
# Check that we got a valid response
assert response.content
assert isinstance(response.content, str)
assert "cat" in response.content.lower()
except ReadTimeout:
elapsed_time = time.time() - t0
logging.info(f"Request timed out after {elapsed_time:.2f} seconds")
assert elapsed_time >= 3.0
pytest.skip("Test timed out as expected with short timeout")
except Exception as e:
logging.error(f"Unexpected exception: {e}")
raise