mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(mistralai): remove tenacity retries for embeddings (#33491)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Callable, Iterable
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from httpx import Response
|
from httpx import Response
|
||||||
@@ -57,6 +57,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
api_key:
|
api_key:
|
||||||
The API key for the MistralAI API. If not provided, it will be read from the
|
The API key for the MistralAI API. If not provided, it will be read from the
|
||||||
environment variable `MISTRAL_API_KEY`.
|
environment variable `MISTRAL_API_KEY`.
|
||||||
|
max_concurrent_requests: int
|
||||||
max_retries:
|
max_retries:
|
||||||
The number of times to retry a request if it fails.
|
The number of times to retry a request if it fails.
|
||||||
timeout:
|
timeout:
|
||||||
@@ -133,9 +134,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
|
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
|
||||||
)
|
)
|
||||||
endpoint: str = "https://api.mistral.ai/v1/"
|
endpoint: str = "https://api.mistral.ai/v1/"
|
||||||
max_retries: int = 5
|
max_retries: int | None = 5
|
||||||
timeout: int = 120
|
timeout: int = 120
|
||||||
wait_time: int = 30
|
wait_time: int | None = 30
|
||||||
max_concurrent_requests: int = 64
|
max_concurrent_requests: int = 64
|
||||||
tokenizer: Tokenizer = Field(default=None)
|
tokenizer: Tokenizer = Field(default=None)
|
||||||
|
|
||||||
@@ -212,6 +213,18 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
if batch:
|
if batch:
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
def _retry(self, func: Callable) -> Callable:
|
||||||
|
if self.max_retries is None or self.wait_time is None:
|
||||||
|
return func
|
||||||
|
|
||||||
|
return retry(
|
||||||
|
retry=retry_if_exception_type(
|
||||||
|
(httpx.TimeoutException, httpx.HTTPStatusError)
|
||||||
|
),
|
||||||
|
wait=wait_fixed(self.wait_time),
|
||||||
|
stop=stop_after_attempt(self.max_retries),
|
||||||
|
)(func)
|
||||||
|
|
||||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
"""Embed a list of document texts.
|
"""Embed a list of document texts.
|
||||||
|
|
||||||
@@ -225,13 +238,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
try:
|
try:
|
||||||
batch_responses = []
|
batch_responses = []
|
||||||
|
|
||||||
@retry(
|
@self._retry
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(httpx.TimeoutException, httpx.HTTPStatusError)
|
|
||||||
),
|
|
||||||
wait=wait_fixed(self.wait_time),
|
|
||||||
stop=stop_after_attempt(self.max_retries),
|
|
||||||
)
|
|
||||||
def _embed_batch(batch: list[str]) -> Response:
|
def _embed_batch(batch: list[str]) -> Response:
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
url="/embeddings",
|
url="/embeddings",
|
||||||
@@ -266,13 +273,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
@retry(
|
@self._retry
|
||||||
retry=retry_if_exception_type(
|
|
||||||
(httpx.TimeoutException, httpx.HTTPStatusError)
|
|
||||||
),
|
|
||||||
wait=wait_fixed(self.wait_time),
|
|
||||||
stop=stop_after_attempt(self.max_retries),
|
|
||||||
)
|
|
||||||
async def _aembed_batch(batch: list[str]) -> Response:
|
async def _aembed_batch(batch: list[str]) -> Response:
|
||||||
response = await self.async_client.post(
|
response = await self.async_client.post(
|
||||||
url="/embeddings",
|
url="/embeddings",
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ async def test_mistralai_embedding_documents_async() -> None:
|
|||||||
assert len(output[0]) == 1024
|
assert len(output[0]) == 1024
|
||||||
|
|
||||||
|
|
||||||
async def test_mistralai_embedding_documents_http_error_async() -> None:
|
async def test_mistralai_embedding_documents_tenacity_error_async() -> None:
|
||||||
"""Test MistralAI embeddings for documents."""
|
"""Test MistralAI embeddings for documents."""
|
||||||
documents = ["foo bar", "test document"]
|
documents = ["foo bar", "test document"]
|
||||||
embedding = MistralAIEmbeddings(max_retries=0)
|
embedding = MistralAIEmbeddings(max_retries=0)
|
||||||
@@ -50,6 +50,21 @@ async def test_mistralai_embedding_documents_http_error_async() -> None:
|
|||||||
await embedding.aembed_documents(documents)
|
await embedding.aembed_documents(documents)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mistralai_embedding_documents_http_error_async() -> None:
|
||||||
|
"""Test MistralAI embeddings for documents."""
|
||||||
|
documents = ["foo bar", "test document"]
|
||||||
|
embedding = MistralAIEmbeddings(max_retries=None)
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
status_code=400,
|
||||||
|
request=httpx.Request("POST", url=embedding.async_client.base_url),
|
||||||
|
)
|
||||||
|
with (
|
||||||
|
patch.object(embedding.async_client, "post", return_value=mock_response),
|
||||||
|
pytest.raises(httpx.HTTPStatusError),
|
||||||
|
):
|
||||||
|
await embedding.aembed_documents(documents)
|
||||||
|
|
||||||
|
|
||||||
async def test_mistralai_embedding_query_async() -> None:
|
async def test_mistralai_embedding_query_async() -> None:
|
||||||
"""Test MistralAI embeddings for query."""
|
"""Test MistralAI embeddings for query."""
|
||||||
document = "foo bar"
|
document = "foo bar"
|
||||||
|
|||||||
Reference in New Issue
Block a user