Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-27 00:48:45 +00:00
Added retries for Vertex LLM (#7219)

#7217

Co-authored-by: Leonid Kuligin <kuligin@google.com>

commit 8b19f6a0da (parent ec66d5188c)
langchain/llms/base.py

@@ -1,14 +1,36 @@
 """Base interface for large language models to expose."""
+from __future__ import annotations
+
 import asyncio
 import inspect
 import json
+import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 import yaml
 from pydantic import Field, root_validator, validator
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_base,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 import langchain
 from langchain.base_language import BaseLanguageModel
@@ -29,11 +51,34 @@ from langchain.schema import (
 )
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 
+logger = logging.getLogger(__name__)
+
 
 def _get_verbosity() -> bool:
     return langchain.verbose
 
 
+def create_base_retry_decorator(
+    error_types: List[Type[BaseException]], max_retries: int = 1
+) -> Callable[[Any], Any]:
+    """Create a retry decorator for a given LLM and provided list of error types."""
+
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+    retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
+    for error in error_types[1:]:
+        retry_instance = retry_instance | retry_if_exception_type(error)
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=retry_instance,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
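The new create_base_retry_decorator is a plain decorator factory, so it can be exercised outside any LLM wrapper. A minimal usage sketch, not part of the diff — TransientError and flaky_call are hypothetical names, and the backoff timings follow from the hard-coded min_seconds/max_seconds above:

# Hypothetical usage sketch (assumes this commit's langchain is installed).
from langchain.llms.base import create_base_retry_decorator

class TransientError(Exception):
    """Stand-in for a provider's transient throttling error."""

calls = {"n": 0}

@create_base_retry_decorator(error_types=[TransientError], max_retries=3)
def flaky_call() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("temporarily throttled")
    return "ok"

print(flaky_call())  # fails twice, waits ~4s before each retry, then prints "ok"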
langchain/llms/openai.py

@@ -21,19 +21,12 @@ from typing import (
 )
 
 from pydantic import Field, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM
+from langchain.llms.base import BaseLLM, create_base_retry_decorator
 from langchain.schema import Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
@@ -76,23 +69,14 @@ def _streaming_response_template() -> Dict[str, Any]:
 def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
     import openai
 
-    min_seconds = 4
-    max_seconds = 10
-    # Wait 2^x * 1 second between each retry starting with
-    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(llm.max_retries),
-        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(openai.error.Timeout)
-            | retry_if_exception_type(openai.error.APIError)
-            | retry_if_exception_type(openai.error.APIConnectionError)
-            | retry_if_exception_type(openai.error.RateLimitError)
-            | retry_if_exception_type(openai.error.ServiceUnavailableError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
+    errors = [
+        openai.error.Timeout,
+        openai.error.APIError,
+        openai.error.APIConnectionError,
+        openai.error.RateLimitError,
+        openai.error.ServiceUnavailableError,
+    ]
+    return create_base_retry_decorator(error_types=errors, max_retries=llm.max_retries)
 
 
 def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
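This openai.py change is a pure refactor: folding the error list with tenacity's | operator builds the same combined predicate the old hand-written chain produced. A small equivalence sketch using tenacity directly (the exception types are illustrative stand-ins, not openai.error ones):

# Equivalence sketch: folding a list reproduces an explicit `|` chain.
from tenacity import retry_if_exception_type

errors = [TimeoutError, ConnectionError, InterruptedError]
combined = retry_if_exception_type(errors[0])
for err in errors[1:]:
    combined = combined | retry_if_exception_type(err)
# `combined` now matches any of the three, exactly like
# retry_if_exception_type(TimeoutError)
#     | retry_if_exception_type(ConnectionError)
#     | retry_if_exception_type(InterruptedError)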
langchain/llms/vertexai.py

@@ -1,7 +1,9 @@
 """Wrapper around Google VertexAI models."""
+from __future__ import annotations
+
 import asyncio
 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
@@ -9,7 +11,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
     init_vertexai,
@@ -24,6 +26,32 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name
 
 
+def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=llm.max_retries  # type: ignore
+    )
+    return decorator
+
+
+def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.predict(*args, **kwargs)
+
+    return _completion_with_retry(*args, **kwargs)
+
+
 class _VertexAICommon(BaseModel):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
@@ -51,6 +79,8 @@ class _VertexAICommon(BaseModel):
     request_parallelism: int = 5
     "The amount of parallelism allowed for requests issued to VertexAI models. "
     "Default is 5."
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
     task_executor: ClassVar[Optional[Executor]] = None
 
     @property
@@ -76,7 +106,7 @@ class _VertexAICommon(BaseModel):
         self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
     ) -> str:
         params = {**self._default_params, **kwargs}
-        res = self.client.predict(prompt, **params)
+        res = completion_with_retry(self, prompt, **params)  # type: ignore
         return self._enforce_stop_words(res.text, stop)
 
     def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
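Taken together, the vertexai.py changes route every predict call through the retry policy, so transient ResourceExhausted, ServiceUnavailable, Aborted, and DeadlineExceeded errors back off and retry before surfacing. A hedged end-to-end sketch (assumes google-cloud-aiplatform is installed and GCP credentials are configured; the model name and prompt are placeholders):

# End-to-end sketch; requires a working Vertex AI setup to actually run.
from langchain.llms import VertexAI

# max_retries (default 6) bounds the attempts before the error is re-raised.
llm = VertexAI(model_name="text-bison", max_retries=6)
print(llm("What day comes after Friday?"))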