Added retries for Vertex LLM (#7219)

#7217

---------

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Leonid Kuligin, 2023-07-06 15:38:01 +02:00, committed by GitHub
parent ec66d5188c
commit 8b19f6a0da
3 changed files with 88 additions and 29 deletions

langchain/llms/base.py

@@ -1,14 +1,36 @@
 """Base interface for large language models to expose."""
+from __future__ import annotations
+
 import asyncio
 import inspect
 import json
+import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 import yaml
 from pydantic import Field, root_validator, validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_base,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 import langchain
 from langchain.base_language import BaseLanguageModel
@@ -29,11 +51,34 @@ from langchain.schema import (
 )
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 
+logger = logging.getLogger(__name__)
+
+
 def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+def create_base_retry_decorator(
+    error_types: List[Type[BaseException]], max_retries: int = 1
+) -> Callable[[Any], Any]:
+    """Create a retry decorator for a given LLM and provided list of error types."""
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+    retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
+    for error in error_types[1:]:
+        retry_instance = retry_instance | retry_if_exception_type(error)
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_instance,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
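
A minimal usage sketch of the new shared helper (the error type and flaky function below are hypothetical, not part of the commit): only the listed exception types trigger a retry, the wait grows exponentially from 4 seconds up to a 10-second cap, and the last exception is re-raised once the attempt budget is spent.

```python
from langchain.llms.base import create_base_retry_decorator


class TransientError(Exception):
    """Hypothetical stand-in for a provider's transient error type."""


calls = {"n": 0}


def flaky_call() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("temporarily unavailable")
    return "ok"


retrying = create_base_retry_decorator(error_types=[TransientError], max_retries=6)
print(retrying(flaky_call)())  # fails twice, backs off ~4s each time, prints "ok"
```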

langchain/llms/openai.py

@@ -21,19 +21,12 @@ from typing import (
 )
 
 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM
+from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -76,23 +69,14 @@ def _streaming_response_template() -> Dict[str, Any]:
 
 def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
     import openai
 
-    min_seconds = 4
-    max_seconds = 10
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
 
 
 def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
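
The refactor should be behavior-preserving: the loop in `create_base_retry_decorator` folds the error list into the same tenacity condition the old code spelled out by chaining `retry_if_exception_type` with `|`. A small equivalence sketch, using stand-in exception types rather than the real `openai.error` classes:

```python
from tenacity import retry_if_exception_type

errors = [ValueError, KeyError, TimeoutError]  # stand-ins for openai.error types

# What the new helper builds internally:
folded = retry_if_exception_type(errors[0])
for err in errors[1:]:
    folded = folded | retry_if_exception_type(err)

# What the old code wrote by hand -- the same retry_any condition:
explicit = (
    retry_if_exception_type(ValueError)
    | retry_if_exception_type(KeyError)
    | retry_if_exception_type(TimeoutError)
)
```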

langchain/llms/vertexai.py

@@ -1,7 +1,9 @@
 """Wrapper around Google VertexAI models."""
+from __future__ import annotations
+
 import asyncio
 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
@@ -9,7 +11,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
     init_vertexai,
@@ -24,6 +26,32 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name
 
 
+def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries  # type: ignore
+    )
+    return decorator
+
+
+def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.predict(*args, **kwargs)
+
+    return _completion_with_retry(*args, **kwargs)
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
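
Note the pattern here: the decorator is built inside `completion_with_retry` at call time rather than applied to a method at class-definition time, because the retry policy depends on per-instance state (`llm.max_retries`) and on `google.api_core`, which is imported lazily. A stripped-down sketch of the same pattern, with hypothetical names and a built-in error type standing in for the Google ones:

```python
from typing import Any

from langchain.llms.base import create_base_retry_decorator


def call_with_retry(client: Any, max_retries: int, prompt: str) -> str:
    # The policy is built per call, once instance-level settings are known.
    policy = create_base_retry_decorator(
        error_types=[ConnectionError], max_retries=max_retries
    )

    @policy
    def _call() -> str:
        return client.predict(prompt)  # hypothetical client with a predict() method

    return _call()
```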
@@ -51,6 +79,8 @@ class _VertexAICommon(BaseModel):
     request_parallelism: int = 5
     "The amount of parallelism allowed for requests issued to VertexAI models. "
     "Default is 5."
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
     task_executor: ClassVar[Optional[Executor]] = None
 
     @property
@@ -76,7 +106,7 @@ class _VertexAICommon(BaseModel):
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
     ) -> str:
         params = {**self._default_params, **kwargs}
-        res = self.client.predict(prompt, **params)
+        res = completion_with_retry(self, prompt, **params)  # type: ignore
         return self._enforce_stop_words(res.text, stop)
 
     def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
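
With these changes, transient Vertex AI failures such as quota errors (`ResourceExhausted`, i.e. HTTP 429), `ServiceUnavailable`, `Aborted`, and `DeadlineExceeded` are retried with exponential backoff instead of failing the call immediately. A hedged end-to-end sketch (assumes Google Cloud credentials and the Vertex AI SDK are already configured):

```python
from langchain.llms import VertexAI

# max_retries is the field added by this commit (default 6).
llm = VertexAI(model_name="text-bison", max_retries=6)

# The predict path now routes through completion_with_retry(), so a burst
# of 429s is absorbed by backoff rather than raised to the caller.
print(llm("Tell me a joke"))
```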