diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index ad4f69c0e99..3670d9e3e8e 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -267,20 +267,29 @@ class MistralAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. - """ try: + + @retry( + retry=retry_if_exception_type( + (httpx.TimeoutException, httpx.HTTPStatusError) + ), + wait=wait_fixed(self.wait_time), + stop=stop_after_attempt(self.max_retries), + ) + async def _aembed_batch(batch: list[str]) -> Response: + response = await self.async_client.post( + url="/embeddings", + json={ + "model": self.model, + "input": batch, + }, + ) + response.raise_for_status() + return response + batch_responses = await asyncio.gather( - *[ - self.async_client.post( - url="/embeddings", - json={ - "model": self.model, - "input": batch, - }, - ) - for batch in self._get_batches(texts) - ] + *[_aembed_batch(batch) for batch in self._get_batches(texts)] ) return [ list(map(float, embedding_obj["embedding"])) diff --git a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py index 299feb6f935..3ef91728e2f 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py @@ -1,5 +1,11 @@ """Test MistralAI Embedding.""" +from unittest.mock import patch + +import httpx +import pytest +import tenacity + from langchain_mistralai import MistralAIEmbeddings @@ -29,6 +35,21 @@ async def test_mistralai_embedding_documents_async() -> None: assert len(output[0]) == 1024 +async def test_mistralai_embedding_documents_http_error_async() -> None: + """Test MistralAI embeddings for documents.""" + documents = ["foo bar", "test document"] + embedding = MistralAIEmbeddings(max_retries=0) + mock_response = httpx.Response( + status_code=400, + request=httpx.Request("POST", url=embedding.async_client.base_url), + ) + with ( + patch.object(embedding.async_client, "post", return_value=mock_response), + pytest.raises(tenacity.RetryError), + ): + await embedding.aembed_documents(documents) + + async def test_mistralai_embedding_query_async() -> None: """Test MistralAI embeddings for query.""" document = "foo bar"