mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 10:39:23 +00:00
Embaas - added backoff retries for network requests (#13679)
Running a large number of requests to Embaas' servers (or any server) can result in intermittent network failures (both from local and external network/service issues). This PR implements exponential backoff retries to help mitigate this issue.
This commit is contained in:
parent
f26d88ca60
commit
9f9cb71d26
@ -1,10 +1,11 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
# Currently supported maximum batch size for embedding requests
|
||||
@ -51,6 +52,10 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
|
||||
api_url: str = EMBAAS_API_URL
|
||||
"""The URL for the embaas embeddings API."""
|
||||
embaas_api_key: Optional[str] = None
|
||||
"""max number of retries for requests"""
|
||||
max_retries: Optional[int] = 3
|
||||
"""request timeout in seconds"""
|
||||
timeout: Optional[int] = 30
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -85,8 +90,22 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
session = requests.Session()
|
||||
retries = Retry(
|
||||
total=self.max_retries,
|
||||
backoff_factor=0.5,
|
||||
allowed_methods=["POST"],
|
||||
raise_on_status=True,
|
||||
)
|
||||
|
||||
session.mount("http://", HTTPAdapter(max_retries=retries))
|
||||
session.mount("https://", HTTPAdapter(max_retries=retries))
|
||||
response = session.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
parsed_response = response.json()
|
||||
embeddings = [item["embedding"] for item in parsed_response["data"]]
|
||||
|
Loading…
Reference in New Issue
Block a user