mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
fix(mistralai): update exception retry logic in embeddings (#35238)
This commit is contained in:
@@ -16,7 +16,7 @@ from pydantic import (
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed
|
||||
from tokenizers import Tokenizer # type: ignore[import]
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -29,6 +29,25 @@ of tokens that can be sent in a single request to the Mistral API (across multip
|
||||
documents/chunks)"""
|
||||
|
||||
|
||||
def _is_retryable_error(exception: BaseException) -> bool:
|
||||
"""Determine if an exception should trigger a retry.
|
||||
|
||||
Only retries on:
|
||||
- Timeout exceptions
|
||||
- 429 (rate limit) errors
|
||||
- 5xx (server) errors
|
||||
|
||||
Does NOT retry on 400 (bad request) or other 4xx client errors.
|
||||
"""
|
||||
if isinstance(exception, httpx.TimeoutException):
|
||||
return True
|
||||
if isinstance(exception, httpx.HTTPStatusError):
|
||||
status_code = exception.response.status_code
|
||||
# Retry on rate limit (429) or server errors (5xx)
|
||||
return status_code == 429 or status_code >= 500
|
||||
return False
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
"""Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""
|
||||
|
||||
@@ -225,9 +244,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
return func
|
||||
|
||||
return retry(
|
||||
retry=retry_if_exception_type(
|
||||
(httpx.TimeoutException, httpx.HTTPStatusError)
|
||||
),
|
||||
retry=retry_if_exception(_is_retryable_error),
|
||||
wait=wait_fixed(self.wait_time),
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
)(func)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import os
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
from langchain_mistralai.embeddings import (
|
||||
DummyTokenizer,
|
||||
_is_retryable_error,
|
||||
)
|
||||
|
||||
os.environ["MISTRAL_API_KEY"] = "foo"
|
||||
|
||||
@@ -15,3 +21,60 @@ def test_mistral_init() -> None:
|
||||
]:
|
||||
assert model.model == "mistral-embed"
|
||||
assert cast("SecretStr", model.mistral_api_key).get_secret_value() == "test"
|
||||
|
||||
|
||||
def test_is_retryable_error_timeout() -> None:
|
||||
"""Test that timeout exceptions are retryable."""
|
||||
exc = httpx.TimeoutException("timeout")
|
||||
assert _is_retryable_error(exc) is True
|
||||
|
||||
|
||||
def test_is_retryable_error_rate_limit() -> None:
|
||||
"""Test that 429 errors are retryable."""
|
||||
response = MagicMock()
|
||||
response.status_code = 429
|
||||
exc = httpx.HTTPStatusError("rate limit", request=MagicMock(), response=response)
|
||||
assert _is_retryable_error(exc) is True
|
||||
|
||||
|
||||
def test_is_retryable_error_server_error() -> None:
|
||||
"""Test that 5xx errors are retryable."""
|
||||
for status_code in [500, 502, 503, 504]:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
exc = httpx.HTTPStatusError(
|
||||
"server error", request=MagicMock(), response=response
|
||||
)
|
||||
assert _is_retryable_error(exc) is True
|
||||
|
||||
|
||||
def test_is_retryable_error_bad_request_not_retryable() -> None:
|
||||
"""Test that 400 errors are NOT retryable."""
|
||||
response = MagicMock()
|
||||
response.status_code = 400
|
||||
exc = httpx.HTTPStatusError("bad request", request=MagicMock(), response=response)
|
||||
assert _is_retryable_error(exc) is False
|
||||
|
||||
|
||||
def test_is_retryable_error_other_4xx_not_retryable() -> None:
|
||||
"""Test that other 4xx errors are NOT retryable."""
|
||||
for status_code in [401, 403, 404, 422]:
|
||||
response = MagicMock()
|
||||
response.status_code = status_code
|
||||
exc = httpx.HTTPStatusError(
|
||||
"client error", request=MagicMock(), response=response
|
||||
)
|
||||
assert _is_retryable_error(exc) is False
|
||||
|
||||
|
||||
def test_is_retryable_error_other_exceptions() -> None:
|
||||
"""Test that other exceptions are not retryable."""
|
||||
assert _is_retryable_error(ValueError("test")) is False
|
||||
assert _is_retryable_error(RuntimeError("test")) is False
|
||||
|
||||
|
||||
def test_dummy_tokenizer() -> None:
|
||||
"""Test that DummyTokenizer returns character lists."""
|
||||
tokenizer = DummyTokenizer()
|
||||
result = tokenizer.encode_batch(["hello", "world"])
|
||||
assert result == [["h", "e", "l", "l", "o"], ["w", "o", "r", "l", "d"]]
|
||||
|
||||
Reference in New Issue
Block a user