From cdbe6c34f97415bcc182a130e02d18e3f678fd15 Mon Sep 17 00:00:00 2001 From: Ademola Balogun <34436072+ademicho123@users.noreply.github.com> Date: Mon, 16 Feb 2026 02:25:15 +0000 Subject: [PATCH] fix(mistralai): update exception retry logic in embeddings (#35238) --- .../langchain_mistralai/embeddings.py | 25 ++++++-- .../tests/unit_tests/test_embeddings.py | 63 +++++++++++++++++++ 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 7c0ad474c68..30311adac8c 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -16,7 +16,7 @@ from pydantic import ( SecretStr, model_validator, ) -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed +from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed from tokenizers import Tokenizer # type: ignore[import] from typing_extensions import Self @@ -29,6 +29,25 @@ of tokens that can be sent in a single request to the Mistral API (across multip documents/chunks)""" +def _is_retryable_error(exception: BaseException) -> bool: + """Determine if an exception should trigger a retry. + + Only retries on: + - Timeout exceptions + - 429 (rate limit) errors + - 5xx (server) errors + + Does NOT retry on 400 (bad request) or other 4xx client errors. + """ + if isinstance(exception, httpx.TimeoutException): + return True + if isinstance(exception, httpx.HTTPStatusError): + status_code = exception.response.status_code + # Retry on rate limit (429) or server errors (5xx) + return status_code == 429 or status_code >= 500 + return False + + class DummyTokenizer: """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface).""" @@ -225,9 +244,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): return func return retry( - retry=retry_if_exception_type( - (httpx.TimeoutException, httpx.HTTPStatusError) - ), + retry=retry_if_exception(_is_retryable_error), wait=wait_fixed(self.wait_time), stop=stop_after_attempt(self.max_retries), )(func) diff --git a/libs/partners/mistralai/tests/unit_tests/test_embeddings.py b/libs/partners/mistralai/tests/unit_tests/test_embeddings.py index 3b4a2472fae..ddd6cad6af8 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/unit_tests/test_embeddings.py @@ -1,9 +1,15 @@ import os from typing import cast +from unittest.mock import MagicMock +import httpx from pydantic import SecretStr from langchain_mistralai import MistralAIEmbeddings +from langchain_mistralai.embeddings import ( + DummyTokenizer, + _is_retryable_error, +) os.environ["MISTRAL_API_KEY"] = "foo" @@ -15,3 +21,60 @@ def test_mistral_init() -> None: ]: assert model.model == "mistral-embed" assert cast("SecretStr", model.mistral_api_key).get_secret_value() == "test" + + +def test_is_retryable_error_timeout() -> None: + """Test that timeout exceptions are retryable.""" + exc = httpx.TimeoutException("timeout") + assert _is_retryable_error(exc) is True + + +def test_is_retryable_error_rate_limit() -> None: + """Test that 429 errors are retryable.""" + response = MagicMock() + response.status_code = 429 + exc = httpx.HTTPStatusError("rate limit", request=MagicMock(), response=response) + assert _is_retryable_error(exc) is True + + +def test_is_retryable_error_server_error() -> None: + """Test that 5xx errors are retryable.""" + for status_code in [500, 502, 503, 504]: + response = MagicMock() + response.status_code = status_code + exc = httpx.HTTPStatusError( + "server error", request=MagicMock(), response=response + ) + assert _is_retryable_error(exc) is True + + +def test_is_retryable_error_bad_request_not_retryable() -> None: + """Test that 400 errors are NOT retryable.""" + response = MagicMock() + response.status_code = 400 + exc = httpx.HTTPStatusError("bad request", request=MagicMock(), response=response) + assert _is_retryable_error(exc) is False + + +def test_is_retryable_error_other_4xx_not_retryable() -> None: + """Test that other 4xx errors are NOT retryable.""" + for status_code in [401, 403, 404, 422]: + response = MagicMock() + response.status_code = status_code + exc = httpx.HTTPStatusError( + "client error", request=MagicMock(), response=response + ) + assert _is_retryable_error(exc) is False + + +def test_is_retryable_error_other_exceptions() -> None: + """Test that other exceptions are not retryable.""" + assert _is_retryable_error(ValueError("test")) is False + assert _is_retryable_error(RuntimeError("test")) is False + + +def test_dummy_tokenizer() -> None: + """Test that DummyTokenizer returns character lists.""" + tokenizer = DummyTokenizer() + result = tokenizer.encode_batch(["hello", "world"]) + assert result == [["h", "e", "l", "l", "o"], ["w", "o", "r", "l", "d"]]