mistral[minor]: Added Retrying Mechanism in case of Request Rate Limit Error for MistralAIEmbeddings (#27818)

- **Description:** When the MistralAI server returns a rate-limit error,
parsing the response JSON raises a `KeyError`. To address this, a simple
retry mechanism has been implemented to handle cases where the request
limit is exceeded (see the usage sketch below).
  - **Issue:** #27790
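
For illustration, a minimal usage sketch of the new knob (assumptions: `MISTRAL_API_KEY` is set in the environment, and the values below are arbitrary, not recommendations):

```python
from langchain_mistralai import MistralAIEmbeddings

# wait_time and max_retries are the fields this PR feeds into the retry
# decorator; the values here are illustrative only.
embeddings = MistralAIEmbeddings(
    model="mistral-embed",
    max_retries=5,  # total attempts before the failure propagates
    wait_time=30,   # seconds to sleep between attempts
)

# Each batch request is now retried on timeout instead of failing outright.
vectors = embeddings.embed_documents(["hello", "world"])
```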

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Authored by Mohammad Mohtashim on 2024-12-12 03:53:42 +05:00, committed by GitHub
parent df5008fe55
commit a37afbe353

libs/partners/mistralai/langchain_mistralai/embeddings.py

@@ -4,6 +4,7 @@ import warnings
 from typing import Iterable, List
 
 import httpx
+from httpx import Response
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import (
     secret_from_env,
@@ -15,6 +16,7 @@ from pydantic import (
     SecretStr,
     model_validator,
 )
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
 from tokenizers import Tokenizer  # type: ignore
 from typing_extensions import Self
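
For readers unfamiliar with tenacity, here is a minimal standalone sketch (not part of the diff) of how this decorator combination behaves:

```python
import httpx
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

# Retry only on httpx.TimeoutException, sleep a fixed 2 seconds between
# attempts, and stop after 3 attempts in total. When attempts are exhausted,
# tenacity raises RetryError wrapping the last exception (pass reraise=True
# to re-raise the original exception instead).
@retry(
    retry=retry_if_exception_type(httpx.TimeoutException),
    wait=wait_fixed(2),
    stop=stop_after_attempt(3),
)
def fetch() -> httpx.Response:
    # Deliberately tiny timeout so the call is likely to time out and retry.
    return httpx.get("https://api.mistral.ai/v1/", timeout=0.001)
```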
@@ -58,6 +60,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
         max_retries: int
             The number of times to retry a request if it fails.
         timeout: int
             The number of seconds to wait for a response before timing out.
+        wait_time: int
+            The number of seconds to wait before retrying a request in case of 429 error.
         max_concurrent_requests: int
             The maximum number of concurrent requests to make to the Mistral API.
@@ -128,6 +132,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     endpoint: str = "https://api.mistral.ai/v1/"
     max_retries: int = 5
     timeout: int = 120
+    wait_time: int = 30
     max_concurrent_requests: int = 64
 
     tokenizer: Tokenizer = Field(default=None)
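
With these defaults (`max_retries=5`, `wait_time=30`), and given that tenacity's `stop_after_attempt` counts total attempts, a persistently failing batch makes 5 attempts separated by 4 fixed waits, i.e. up to 4 × 30 = 120 seconds of sleeping on top of the request timeouts before the exception propagates.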
@@ -215,16 +220,26 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
             List of embeddings, one for each text.
         """
         try:
-            batch_responses = (
-                self.client.post(
+            batch_responses = []
+
+            @retry(
+                retry=retry_if_exception_type(httpx.TimeoutException),
+                wait=wait_fixed(self.wait_time),
+                stop=stop_after_attempt(self.max_retries),
+            )
+            def _embed_batch(batch: List[str]) -> Response:
+                response = self.client.post(
                     url="/embeddings",
                     json=dict(
                         model=self.model,
                         input=batch,
                     ),
                 )
-                for batch in self._get_batches(texts)
-            )
+                response.raise_for_status()
+                return response
+
+            for batch in self._get_batches(texts):
+                batch_responses.append(_embed_batch(batch))
             return [
                 list(map(float, embedding_obj["embedding"]))
                 for response in batch_responses
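
Net effect of this final hunk: the lazy generator expression becomes an eager loop, timeouts are retried by the tenacity-decorated helper, and `raise_for_status()` turns HTTP error responses (such as a 429) into an immediate `httpx.HTTPStatusError` instead of a confusing `KeyError` when the error body is parsed later.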