Embaas - added backoff retries for network requests (#13679)

Sending a large number of requests to Embaas' servers (or any server)
can result in intermittent network failures, caused by both local and
external network/service issues. This PR implements exponential backoff
retries to help mitigate this issue.
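
For readers unfamiliar with the mechanism: requests lets a urllib3 Retry policy be mounted on a Session through an HTTPAdapter, after which the session transparently retries failed requests with exponentially growing delays. A minimal, self-contained sketch of the approach, using the same retry settings as this PR (the endpoint URL and payload below are illustrative placeholders, not taken from the Embaas API):

import requests
from requests.adapters import HTTPAdapter, Retry

# Build a session whose POST requests are retried with exponential backoff.
session = requests.Session()
retries = Retry(
    total=3,                   # give up after 3 retries
    backoff_factor=0.5,        # exponential backoff between retries
    allowed_methods=["POST"],  # POST is not retried by default, so opt in
    raise_on_status=True,      # raise instead of returning the response when status-based retries run out
)
session.mount("http://", HTTPAdapter(max_retries=retries))
session.mount("https://", HTTPAdapter(max_retries=retries))

# Placeholder endpoint and payload, for illustration only.
response = session.post(
    "https://api.example.com/v1/embeddings/",
    json={"texts": ["hello world"]},
    timeout=30,
)
response.raise_for_status()
print(response.json())

Note that the allowed_methods parameter requires urllib3 1.26 or newer; older releases spelled it method_whitelist.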
Colin Ulin, 2023-12-04 19:21:35 -05:00, committed by GitHub
commit 9f9cb71d26 (parent f26d88ca60)


@@ -1,10 +1,11 @@
 from typing import Any, Dict, List, Mapping, Optional
 
 import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
+from requests.adapters import HTTPAdapter, Retry
 from typing_extensions import NotRequired, TypedDict
 
-from langchain.pydantic_v1 import BaseModel, Extra, root_validator
-from langchain.schema.embeddings import Embeddings
 from langchain.utils import get_from_dict_or_env
 
 # Currently supported maximum batch size for embedding requests
@@ -51,6 +52,10 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
     api_url: str = EMBAAS_API_URL
     """The URL for the embaas embeddings API."""
     embaas_api_key: Optional[str] = None
+    max_retries: Optional[int] = 3
+    """Max number of retries for requests."""
+    timeout: Optional[int] = 30
+    """Request timeout in seconds."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -85,8 +90,22 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
             "Content-Type": "application/json",
         }
-
-        response = requests.post(self.api_url, headers=headers, json=payload)
-        response.raise_for_status()
+        session = requests.Session()
+        retries = Retry(
+            total=self.max_retries,
+            backoff_factor=0.5,
+            allowed_methods=["POST"],
+            raise_on_status=True,
+        )
+
+        session.mount("http://", HTTPAdapter(max_retries=retries))
+        session.mount("https://", HTTPAdapter(max_retries=retries))
+
+        response = session.post(
+            self.api_url,
+            headers=headers,
+            json=payload,
+            timeout=self.timeout,
+        )
 
         parsed_response = response.json()
         embeddings = [item["embedding"] for item in parsed_response["data"]]
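
For reference on what backoff_factor=0.5 buys: urllib3 sleeps roughly backoff_factor * 2 ** (n - 1) seconds before the n-th retry (1.x releases skip the delay before the first retry), so the default three retries wait on the order of half a second to a couple of seconds rather than hammering the server. A quick sketch of that schedule:

# Approximate sleep before the n-th retry with backoff_factor=0.5.
for n in range(1, 4):
    print(n, 0.5 * 2 ** (n - 1))  # 0.5, 1.0, 2.0 seconds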