feat(mistralai): remove tenacity retries for embeddings (#33491)

This commit is contained in:
noeliecherrier
2025-10-21 18:35:10 +02:00
committed by GitHub
parent 2222470f69
commit 9f470d297f
2 changed files with 34 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
import warnings import warnings
from collections.abc import Iterable from collections.abc import Callable, Iterable
import httpx import httpx
from httpx import Response from httpx import Response
@@ -57,6 +57,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
api_key: api_key:
The API key for the MistralAI API. If not provided, it will be read from the The API key for the MistralAI API. If not provided, it will be read from the
environment variable `MISTRAL_API_KEY`. environment variable `MISTRAL_API_KEY`.
max_concurrent_requests: int
max_retries: max_retries:
The number of times to retry a request if it fails. The number of times to retry a request if it fails.
timeout: timeout:
@@ -133,9 +134,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
default_factory=secret_from_env("MISTRAL_API_KEY", default=""), default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
) )
endpoint: str = "https://api.mistral.ai/v1/" endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5 max_retries: int | None = 5
timeout: int = 120 timeout: int = 120
wait_time: int = 30 wait_time: int | None = 30
max_concurrent_requests: int = 64 max_concurrent_requests: int = 64
tokenizer: Tokenizer = Field(default=None) tokenizer: Tokenizer = Field(default=None)
@@ -212,6 +213,18 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
if batch: if batch:
yield batch yield batch
def _retry(self, func: Callable) -> Callable:
if self.max_retries is None or self.wait_time is None:
return func
return retry(
retry=retry_if_exception_type(
(httpx.TimeoutException, httpx.HTTPStatusError)
),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)(func)
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts. """Embed a list of document texts.
@@ -225,13 +238,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
try: try:
batch_responses = [] batch_responses = []
@retry( @self._retry
retry=retry_if_exception_type(
(httpx.TimeoutException, httpx.HTTPStatusError)
),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
def _embed_batch(batch: list[str]) -> Response: def _embed_batch(batch: list[str]) -> Response:
response = self.client.post( response = self.client.post(
url="/embeddings", url="/embeddings",
@@ -266,13 +273,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
""" """
try: try:
@retry( @self._retry
retry=retry_if_exception_type(
(httpx.TimeoutException, httpx.HTTPStatusError)
),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
async def _aembed_batch(batch: list[str]) -> Response: async def _aembed_batch(batch: list[str]) -> Response:
response = await self.async_client.post( response = await self.async_client.post(
url="/embeddings", url="/embeddings",

View File

@@ -35,7 +35,7 @@ async def test_mistralai_embedding_documents_async() -> None:
assert len(output[0]) == 1024 assert len(output[0]) == 1024
async def test_mistralai_embedding_documents_http_error_async() -> None: async def test_mistralai_embedding_documents_tenacity_error_async() -> None:
"""Test MistralAI embeddings for documents.""" """Test MistralAI embeddings for documents."""
documents = ["foo bar", "test document"] documents = ["foo bar", "test document"]
embedding = MistralAIEmbeddings(max_retries=0) embedding = MistralAIEmbeddings(max_retries=0)
@@ -50,6 +50,21 @@ async def test_mistralai_embedding_documents_http_error_async() -> None:
await embedding.aembed_documents(documents) await embedding.aembed_documents(documents)
async def test_mistralai_embedding_documents_http_error_async() -> None:
"""Test MistralAI embeddings for documents."""
documents = ["foo bar", "test document"]
embedding = MistralAIEmbeddings(max_retries=None)
mock_response = httpx.Response(
status_code=400,
request=httpx.Request("POST", url=embedding.async_client.base_url),
)
with (
patch.object(embedding.async_client, "post", return_value=mock_response),
pytest.raises(httpx.HTTPStatusError),
):
await embedding.aembed_documents(documents)
async def test_mistralai_embedding_query_async() -> None: async def test_mistralai_embedding_query_async() -> None:
"""Test MistralAI embeddings for query.""" """Test MistralAI embeddings for query."""
document = "foo bar" document = "foo bar"