Exponential back-off support for Google PaLM api (#4001)
This PR adds exponential back-off to the Google PaLM API wrappers to gracefully handle rate-limiting errors.

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
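For context, the change wraps every PaLM client call in a tenacity retry decorator, so quota errors trigger exponentially growing waits instead of an immediate failure. A minimal runnable sketch of the same configuration, with a hypothetical `flaky_call` standing in for the real PaLM client call:

```python
import logging

from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential

logger = logging.getLogger(__name__)


@retry(
    reraise=True,  # if all attempts fail, re-raise the original exception
    stop=stop_after_attempt(10),  # same cap as the PR: at most 10 attempts
    wait=wait_exponential(multiplier=2, min=1, max=60),  # doubling waits, capped at 60s
    before_sleep=before_sleep_log(logger, logging.WARNING),  # log before each retry sleep
)
def flaky_call() -> str:
    # A real PaLM call would raise e.g. ResourceExhausted (HTTP 429) when
    # throttled; this hypothetical stand-in simply succeeds.
    return "ok"


print(flaky_call())
```

Note `reraise=True`: once the final attempt fails, the original `google.api_core` exception propagates to the caller rather than tenacity's `RetryError`.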
parent: a6f3ec94bc
commit: a9dbe90447
langchain/chat_models/google_palm.py:

```diff
@@ -1,9 +1,17 @@
 """Wrapper around Google's PaLM Chat API."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -24,6 +32,8 @@ from langchain.utils import get_from_dict_or_env
 if TYPE_CHECKING:
     import google.generativeai as genai
 
+logger = logging.getLogger(__name__)
+
 
 class ChatGooglePalmError(Exception):
     pass
@@ -156,6 +166,51 @@ def _messages_to_prompt_dict(
     )
 
 
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
+    import google.api_core.exceptions
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _chat_with_retry(**kwargs: Any) -> Any:
+        return llm.client.chat(**kwargs)
+
+    return _chat_with_retry(**kwargs)
+
+
+async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the async completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    async def _achat_with_retry(**kwargs: Any) -> Any:
+        # The PaLM client exposes an awaitable chat_async counterpart.
+        return await llm.client.chat_async(**kwargs)
+
+    return await _achat_with_retry(**kwargs)
+
+
 class ChatGooglePalm(BaseChatModel, BaseModel):
     """Wrapper around Google's PaLM Chat API.
 
@@ -227,7 +282,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
-        response: genai.types.ChatResponse = self.client.chat(
+        response: genai.types.ChatResponse = chat_with_retry(
+            self,
             model=self.model_name,
             prompt=prompt,
             temperature=self.temperature,
@@ -246,7 +302,8 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
 
-        response: genai.types.ChatResponse = await self.client.chat_async(
+        response: genai.types.ChatResponse = await achat_with_retry(
+            self,
             model=self.model_name,
             prompt=prompt,
             temperature=self.temperature,
```
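The retried helpers use a small closure pattern: the public function builds the decorator at call time and applies it to an inner function that captures the client, so only the raw API call sits inside the retry loop. A self-contained sketch of that pattern with an invented `DummyClient` (the Google exception filters are omitted here, so this version retries on any exception):

```python
from typing import Any, Callable

from tenacity import retry, stop_after_attempt, wait_exponential


def _create_retry_decorator() -> Callable[[Any], Any]:
    # Same shape as the PR's helper, minus the google.api_core exception filters.
    return retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=60),
    )


class DummyClient:
    """Invented stand-in for the PaLM client."""

    def chat(self, **kwargs: Any) -> dict:
        return {"echo": kwargs}


def chat_with_retry(client: DummyClient, **kwargs: Any) -> Any:
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    def _chat_with_retry(**kwargs: Any) -> Any:
        # Only this raw call is retried; argument handling stays outside.
        return client.chat(**kwargs)

    return _chat_with_retry(**kwargs)


print(chat_with_retry(DummyClient(), prompt="hello"))
```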
langchain/embeddings/google_palm.py:

```diff
@@ -1,16 +1,64 @@
 """Wrapper around Google's PaLM Embeddings APIs."""
-from typing import Any, Dict, List, Optional
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.embeddings.base import Embeddings
 from langchain.utils import get_from_dict_or_env
 
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
+    import google.api_core.exceptions
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def embed_with_retry(
+    embeddings: GooglePalmEmbeddings, *args: Any, **kwargs: Any
+) -> Any:
+    """Use tenacity to retry the embedding call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _embed_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return embeddings.client.generate_embeddings(*args, **kwargs)
+
+    return _embed_with_retry(*args, **kwargs)
+
+
 class GooglePalmEmbeddings(BaseModel, Embeddings):
     client: Any
     google_api_key: Optional[str]
     model_name: str = "models/embedding-gecko-001"
+    """Model name to use."""
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -34,5 +82,5 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
 
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
-        embedding = self.client.generate_embeddings(self.model_name, text)
+        embedding = embed_with_retry(self, self.model_name, text)
         return embedding["embedding"]
```
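One detail worth noting in the `embed_query` change: `embed_with_retry(self, self.model_name, text)` passes its arguments straight through and returns whatever the raw client returned, so the surrounding code still indexes an `embedding` key. A toy illustration of that contract (the fake client and vector are invented for the example):

```python
from typing import Any, Dict, List


class FakeEmbeddingsClient:
    """Invented stand-in mimicking the PaLM embeddings response shape."""

    def generate_embeddings(self, model: str, text: str) -> Dict[str, Any]:
        return {"embedding": [0.1, 0.2, 0.3]}  # fake 3-dim vector


def embed_query(client: FakeEmbeddingsClient, model_name: str, text: str) -> List[float]:
    # Same unwrapping as the PR's embed_query: take the "embedding" key.
    return client.generate_embeddings(model_name, text)["embedding"]


print(embed_query(FakeEmbeddingsClient(), "models/embedding-gecko-001", "hello"))
```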
langchain/llms/google_palm.py:

```diff
@@ -1,9 +1,17 @@
 """Wrapper around Google's PaLM Text APIs."""
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional
+import logging
+from typing import Any, Callable, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -13,6 +21,44 @@ from langchain.llms import BaseLLM
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
+logger = logging.getLogger(__name__)
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions."""
+    try:
+        import google.api_core.exceptions
+    except ImportError:
+        raise ImportError("Could not import google.api_core; install google-api-core.")
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator()
+
+    @retry_decorator
+    def _generate_with_retry(**kwargs: Any) -> Any:
+        return llm.client.generate_text(**kwargs)
+
+    return _generate_with_retry(**kwargs)
+
+
 def _strip_erroneous_leading_spaces(text: str) -> str:
     """Strip erroneous leading spaces from text.
@@ -85,7 +131,8 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = self.client.generate_text(
+            completion = generate_with_retry(
+                self,
                 model=self.model_name,
                 prompt=prompt,
                 stop_sequences=stop,
```
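For a rough sense of the schedule these parameters produce: with `multiplier=2`, `min=1`, `max=60` the waits double until they hit the 60-second cap, and `stop_after_attempt(10)` bounds the whole loop. A back-of-the-envelope sketch (an idealized doubling formula, not tenacity's exact internal arithmetic):

```python
# Idealized view of the configured back-off: each wait doubles,
# clamped to [lo, hi]. Not tenacity's exact internals.
def approx_wait(attempt: int, multiplier: int = 2, lo: int = 1, hi: int = 60) -> int:
    return max(lo, min(multiplier * 2 ** attempt, hi))


print([approx_wait(n) for n in range(1, 10)])
# [4, 8, 16, 32, 60, 60, 60, 60, 60] -- doubling until the 60s cap
```

So in the worst case a single throttled call spends several minutes retrying before `reraise=True` surfaces the original exception to the caller.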