From 9f9cb71d26f297fa84a364b8f19f4afba807d204 Mon Sep 17 00:00:00 2001 From: Colin Ulin <47982430+pocketcolin@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:21:35 -0500 Subject: [PATCH] Embaas - added backoff retries for network requests (#13679) Running a large number of requests to Embaas' servers (or any server) can result in intermittent network failures (both from local and external network/service issues). This PR implements exponential backoff retries to help mitigate this issue. --- libs/langchain/langchain/embeddings/embaas.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/embeddings/embaas.py b/libs/langchain/langchain/embeddings/embaas.py index df7696467e7..4d8e682ef3b 100644 --- a/libs/langchain/langchain/embeddings/embaas.py +++ b/libs/langchain/langchain/embeddings/embaas.py @@ -1,10 +1,11 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from requests.adapters import HTTPAdapter, Retry from typing_extensions import NotRequired, TypedDict +from langchain.pydantic_v1 import BaseModel, Extra, root_validator +from langchain.schema.embeddings import Embeddings from langchain.utils import get_from_dict_or_env # Currently supported maximum batch size for embedding requests @@ -51,6 +52,10 @@ class EmbaasEmbeddings(BaseModel, Embeddings): api_url: str = EMBAAS_API_URL """The URL for the embaas embeddings API.""" embaas_api_key: Optional[str] = None + """max number of retries for requests""" + max_retries: Optional[int] = 3 + """request timeout in seconds""" + timeout: Optional[int] = 30 class Config: """Configuration for this pydantic object.""" @@ -85,8 +90,22 @@ class EmbaasEmbeddings(BaseModel, Embeddings): "Content-Type": "application/json", } - response = requests.post(self.api_url, headers=headers, json=payload) - response.raise_for_status() + session = requests.Session() + retries = Retry( + total=self.max_retries, + backoff_factor=0.5, + allowed_methods=["POST"], + raise_on_status=True, + ) + + session.mount("http://", HTTPAdapter(max_retries=retries)) + session.mount("https://", HTTPAdapter(max_retries=retries)) + response = session.post( + self.api_url, + headers=headers, + json=payload, + timeout=self.timeout, + ) parsed_response = response.json() embeddings = [item["embedding"] for item in parsed_response["data"]]