mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 05:30:39 +00:00
Exponential back-off support for Google PaLM api (#4001)
This PR adds exponential back-off to the Google PaLM api to gracefully handle rate limiting errors. --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
a6f3ec94bc
commit
a9dbe90447
@ -1,9 +1,17 @@
|
||||
"""Wrapper around Google's PaLM Chat API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -24,6 +32,8 @@ from langchain.utils import get_from_dict_or_env
|
||||
if TYPE_CHECKING:
|
||||
import google.generativeai as genai
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatGooglePalmError(Exception):
|
||||
pass
|
||||
@ -156,6 +166,51 @@ def _messages_to_prompt_dict(
|
||||
)
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
import google.api_core.exceptions
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _chat_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.chat(**kwargs)
|
||||
|
||||
return _chat_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
async def _achat_with_retry(**kwargs: Any) -> Any:
|
||||
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
|
||||
return await llm.client.chat_async(**kwargs)
|
||||
|
||||
return await _achat_with_retry(**kwargs)
|
||||
|
||||
|
||||
class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
"""Wrapper around Google's PaLM Chat API.
|
||||
|
||||
@ -227,7 +282,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
) -> ChatResult:
|
||||
prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
response: genai.types.ChatResponse = self.client.chat(
|
||||
response: genai.types.ChatResponse = chat_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
temperature=self.temperature,
|
||||
@ -246,7 +302,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
) -> ChatResult:
|
||||
prompt = _messages_to_prompt_dict(messages)
|
||||
|
||||
response: genai.types.ChatResponse = await self.client.chat_async(
|
||||
response: genai.types.ChatResponse = await achat_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
temperature=self.temperature,
|
||||
|
@ -1,16 +1,64 @@
|
||||
"""Wrapper arround Google's PaLM Embeddings APIs."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
import google.api_core.exceptions
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def embed_with_retry(
|
||||
embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
|
||||
return embeddings.client.generate_embeddings(*args, **kwargs)
|
||||
|
||||
return _embed_with_retry(*args, **kwargs)
|
||||
|
||||
|
||||
class GooglePalmEmbeddings(BaseModel, Embeddings):
|
||||
client: Any
|
||||
google_api_key: Optional[str]
|
||||
model_name: str = "models/embedding-gecko-001"
|
||||
"""Model name to use."""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -34,5 +82,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed query text."""
|
||||
embedding = self.client.generate_embeddings(self.model_name, text)
|
||||
embedding = embed_with_retry(self, self.model_name, text)
|
||||
return embedding["embedding"]
|
||||
|
@ -1,9 +1,17 @@
|
||||
"""Wrapper arround Google's PaLM Text APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -13,6 +21,44 @@ from langchain.llms import BaseLLM
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
|
||||
try:
|
||||
import google.api_core.exceptions
|
||||
except ImportError:
|
||||
raise ImportError()
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
|
||||
@retry_decorator
|
||||
def _generate_with_retry(**kwargs: Any) -> Any:
|
||||
return llm.client.generate_text(**kwargs)
|
||||
|
||||
return _generate_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _strip_erroneous_leading_spaces(text: str) -> str:
|
||||
"""Strip erroneous leading spaces from text.
|
||||
@ -85,7 +131,8 @@ class GooglePalm(BaseLLM, BaseModel):
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
completion = self.client.generate_text(
|
||||
completion = generate_with_retry(
|
||||
self,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
stop_sequences=stop,
|
||||
|
Loading…
Reference in New Issue
Block a user