From 50381abc42b85508d83a4326be808aa3dc4c1401 Mon Sep 17 00:00:00 2001
From: Dmitry Tyumentsev <56769451+tyumentsev4@users.noreply.github.com>
Date: Tue, 19 Dec 2023 07:51:42 -0800
Subject: [PATCH] community[patch]: Add retry logic to Yandex GPT API Calls
 (#14907)

**Description:** Added logic for re-calling the YandexGPT API in case of an error

---------

Co-authored-by: Dmitry Tyumentsev
---
 .../langchain_community/chat_models/yandex.py | 238 ++++++++++-------
 .../langchain_community/llms/yandex.py        | 252 +++++++++++-------
 2 files changed, 300 insertions(+), 190 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/yandex.py b/libs/community/langchain_community/chat_models/yandex.py
index 6a8dc556ffc..42e61a91f5e 100644
--- a/libs/community/langchain_community/chat_models/yandex.py
+++ b/libs/community/langchain_community/chat_models/yandex.py
@@ -1,6 +1,8 @@
 """Wrapper around YandexGPT chat models."""
+from __future__ import annotations
+
 import logging
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Callable, Dict, List, Optional, cast
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -14,6 +16,13 @@ from langchain_core.messages import (
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatResult
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain_community.llms.utils import enforce_stop_tokens
 from langchain_community.llms.yandex import _BaseYandexGPT
@@ -80,41 +89,7 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
         Raises:
             ValueError: if the last message in the list is not from human.
         """
-        try:
-            import grpc
-            from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
-                CompletionOptions,
-                Message,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
-                CompletionRequest,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
-                TextGenerationServiceStub,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install YandexCloud SDK" " with `pip install yandexcloud`."
-            ) from e
-        if not messages:
-            raise ValueError(
-                "You should provide at least one message to start the chat!"
-            )
-        message_history = _parse_chat_history(messages)
-        channel_credentials = grpc.ssl_channel_credentials()
-        channel = grpc.secure_channel(self.url, channel_credentials)
-        request = CompletionRequest(
-            model_uri=self.model_uri,
-            completion_options=CompletionOptions(
-                temperature=DoubleValue(value=self.temperature),
-                max_tokens=Int64Value(value=self.max_tokens),
-            ),
-            messages=[Message(**message) for message in message_history],
-        )
-        stub = TextGenerationServiceStub(channel)
-        res = stub.Completion(request, metadata=self._grpc_metadata)
-        text = list(res)[0].alternatives[0].message.text
+        text = completion_with_retry(self, messages=messages)
         text = text if stop is None else enforce_stop_tokens(text, stop)
         message = AIMessage(content=text)
         return ChatResult(generations=[ChatGeneration(message=message)])
@@ -139,62 +114,139 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
         Raises:
             ValueError: if the last message in the list is not from human.
         """
-        try:
-            import asyncio
+        text = await acompletion_with_retry(self, messages=messages)
+        text = text if stop is None else enforce_stop_tokens(text, stop)
+        message = AIMessage(content=text)
+        return ChatResult(generations=[ChatGeneration(message=message)])
 
-            import grpc
-            from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
-                CompletionOptions,
-                Message,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
-                CompletionRequest,
-                CompletionResponse,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
-                TextGenerationAsyncServiceStub,
-            )
-            from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
-            from yandex.cloud.operation.operation_service_pb2_grpc import (
-                OperationServiceStub,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install YandexCloud SDK" " with `pip install yandexcloud`."
-            ) from e
-        if not messages:
-            raise ValueError(
-                "You should provide at least one message to start the chat!"
-            )
-        message_history = _parse_chat_history(messages)
-        operation_api_url = "operation.api.cloud.yandex.net:443"
-        channel_credentials = grpc.ssl_channel_credentials()
-        async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
-            request = CompletionRequest(
-                model_uri=self.model_uri,
-                completion_options=CompletionOptions(
-                    temperature=DoubleValue(value=self.temperature),
-                    max_tokens=Int64Value(value=self.max_tokens),
-                ),
-                messages=[Message(**message) for message in message_history],
-            )
-            stub = TextGenerationAsyncServiceStub(channel)
-            operation = await stub.Completion(request, metadata=self._grpc_metadata)
-            async with grpc.aio.secure_channel(
-                operation_api_url, channel_credentials
-            ) as operation_channel:
-                operation_stub = OperationServiceStub(operation_channel)
-                while not operation.done:
-                    await asyncio.sleep(1)
-                    operation_request = GetOperationRequest(operation_id=operation.id)
-                    operation = await operation_stub.Get(
-                        operation_request, metadata=self._grpc_metadata
-                    )
-            completion_response = CompletionResponse()
-            operation.response.Unpack(completion_response)
-            text = completion_response.alternatives[0].message.text
-            text = text if stop is None else enforce_stop_tokens(text, stop)
-            message = AIMessage(content=text)
-            return ChatResult(generations=[ChatGeneration(message=message)])
+
+def _make_request(
+    self: ChatYandexGPT,
+    messages: List[BaseMessage],
+) -> str:
+    try:
+        import grpc
+        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
+            CompletionOptions,
+            Message,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
+            CompletionRequest,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
+            TextGenerationServiceStub,
+        )
+    except ImportError as e:
+        raise ImportError(
+            "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+        ) from e
+    if not messages:
+        raise ValueError("You should provide at least one message to start the chat!")
+    message_history = _parse_chat_history(messages)
+    channel_credentials = grpc.ssl_channel_credentials()
+    channel = grpc.secure_channel(self.url, channel_credentials)
+    request = CompletionRequest(
+        model_uri=self.model_uri,
+        completion_options=CompletionOptions(
+            temperature=DoubleValue(value=self.temperature),
+            max_tokens=Int64Value(value=self.max_tokens),
+        ),
+        messages=[Message(**message) for message in message_history],
+    )
+    stub = TextGenerationServiceStub(channel)
+    res = stub.Completion(request, metadata=self._grpc_metadata)
+    return list(res)[0].alternatives[0].message.text
+
+
+async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str:
+    try:
+        import asyncio
+
+        import grpc
+        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
+            CompletionOptions,
+            Message,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
+            CompletionRequest,
+            CompletionResponse,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
+            TextGenerationAsyncServiceStub,
+        )
+        from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
+        from yandex.cloud.operation.operation_service_pb2_grpc import (
+            OperationServiceStub,
+        )
+    except ImportError as e:
+        raise ImportError(
+            "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+        ) from e
+    if not messages:
+        raise ValueError("You should provide at least one message to start the chat!")
+    message_history = _parse_chat_history(messages)
+    operation_api_url = "operation.api.cloud.yandex.net:443"
+    channel_credentials = grpc.ssl_channel_credentials()
+    async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
+        request = CompletionRequest(
+            model_uri=self.model_uri,
+            completion_options=CompletionOptions(
+                temperature=DoubleValue(value=self.temperature),
+                max_tokens=Int64Value(value=self.max_tokens),
+            ),
+            messages=[Message(**message) for message in message_history],
+        )
+        stub = TextGenerationAsyncServiceStub(channel)
+        operation = await stub.Completion(request, metadata=self._grpc_metadata)
+        async with grpc.aio.secure_channel(
+            operation_api_url, channel_credentials
+        ) as operation_channel:
+            operation_stub = OperationServiceStub(operation_channel)
+            while not operation.done:
+                await asyncio.sleep(1)
+                operation_request = GetOperationRequest(operation_id=operation.id)
+                operation = await operation_stub.Get(
+                    operation_request, metadata=self._grpc_metadata
+                )
+
+        completion_response = CompletionResponse()
+        operation.response.Unpack(completion_response)
+        return completion_response.alternatives[0].message.text
+
+
+def _create_retry_decorator(llm: ChatYandexGPT) -> Callable[[Any], Any]:
+    from grpc import RpcError
+
+    min_seconds = 1
+    max_seconds = 60
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(llm.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(retry_if_exception_type((RpcError))),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def completion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    def _completion_with_retry(**_kwargs: Any) -> Any:
+        return _make_request(llm, **_kwargs)
+
+    return _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+    """Use tenacity to retry the async completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    async def _completion_with_retry(**_kwargs: Any) -> Any:
+        return await _amake_request(llm, **_kwargs)
+
+    return await _completion_with_retry(**kwargs)
diff --git a/libs/community/langchain_community/llms/yandex.py b/libs/community/langchain_community/llms/yandex.py
index 9691868ddf3..c07efe68310 100644
--- a/libs/community/langchain_community/llms/yandex.py
+++ b/libs/community/langchain_community/llms/yandex.py
@@ -1,4 +1,7 @@
-from typing import Any, Dict, List, Mapping, Optional
+from __future__ import annotations
+
+import logging
+from typing import Any, Callable, Dict, List, Mapping, Optional
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -8,9 +11,18 @@ from langchain_core.language_models.llms import LLM
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import root_validator
 from langchain_core.utils import get_from_dict_or_env
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from langchain_community.llms.utils import enforce_stop_tokens
 
+logger = logging.getLogger(__name__)
+
 
 class _BaseYandexGPT(Serializable):
     iam_token: str = ""
@@ -38,11 +50,24 @@ class _BaseYandexGPT(Serializable):
     """Sequences when completion generation will stop."""
     url: str = "llm.api.cloud.yandex.net:443"
     """The url of the API."""
+    max_retries: int = 6
+    """Maximum number of retries to make when generating."""
 
     @property
     def _llm_type(self) -> str:
         return "yandex_gpt"
 
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model_uri": self.model_uri,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "stop": self.stop,
+            "max_retries": self.max_retries,
+        }
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that iam token exists in environment."""
@@ -99,16 +124,6 @@ class YandexGPT(_BaseYandexGPT, LLM):
             yandex_gpt = YandexGPT(iam_token="t1.9eu...", folder_id="b1g...")
     """
 
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "model_uri": self.model_uri,
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-            "stop": self.stop,
-        }
-
     def _call(
         self,
         prompt: str,
@@ -130,40 +145,7 @@ class YandexGPT(_BaseYandexGPT, LLM):
 
                 response = YandexGPT("Tell me a joke.")
         """
-        try:
-            import grpc
-            from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
-                CompletionOptions,
-                Message,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
-                CompletionRequest,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
-                TextGenerationServiceStub,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install YandexCloud SDK" " with `pip install yandexcloud`."
-            ) from e
-        channel_credentials = grpc.ssl_channel_credentials()
-        channel = grpc.secure_channel(self.url, channel_credentials)
-        request = CompletionRequest(
-            model_uri=self.model_uri,
-            completion_options=CompletionOptions(
-                temperature=DoubleValue(value=self.temperature),
-                max_tokens=Int64Value(value=self.max_tokens),
-            ),
-            messages=[Message(role="user", text=prompt)],
-        )
-        stub = TextGenerationServiceStub(channel)
-        if self.iam_token:
-            metadata = (("authorization", f"Bearer {self.iam_token}"),)
-        else:
-            metadata = (("authorization", f"Api-Key {self.api_key}"),)
-        res = stub.Completion(request, metadata=metadata)
-        text = list(res)[0].alternatives[0].message.text
+        text = completion_with_retry(self, prompt=prompt)
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text
@@ -184,57 +166,133 @@ class YandexGPT(_BaseYandexGPT, LLM):
         Returns:
             The string generated by the model.
         """
-        try:
-            import asyncio
+        text = await acompletion_with_retry(self, prompt=prompt)
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+        return text
 
-            import grpc
-            from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
-                CompletionOptions,
-                Message,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
-                CompletionRequest,
-                CompletionResponse,
-            )
-            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
-                TextGenerationAsyncServiceStub,
-            )
-            from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
-            from yandex.cloud.operation.operation_service_pb2_grpc import (
-                OperationServiceStub,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install YandexCloud SDK" " with `pip install yandexcloud`."
-            ) from e
-        operation_api_url = "operation.api.cloud.yandex.net:443"
-        channel_credentials = grpc.ssl_channel_credentials()
-        async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
-            request = CompletionRequest(
-                model_uri=self.model_uri,
-                completion_options=CompletionOptions(
-                    temperature=DoubleValue(value=self.temperature),
-                    max_tokens=Int64Value(value=self.max_tokens),
-                ),
-                messages=[Message(role="user", text=prompt)],
-            )
-            stub = TextGenerationAsyncServiceStub(channel)
-            operation = await stub.Completion(request, metadata=self._grpc_metadata)
-            async with grpc.aio.secure_channel(
-                operation_api_url, channel_credentials
-            ) as operation_channel:
-                operation_stub = OperationServiceStub(operation_channel)
-                while not operation.done:
-                    await asyncio.sleep(1)
-                    operation_request = GetOperationRequest(operation_id=operation.id)
-                    operation = await operation_stub.Get(
-                        operation_request, metadata=self._grpc_metadata
-                    )
-            completion_response = CompletionResponse()
-            operation.response.Unpack(completion_response)
-            text = completion_response.alternatives[0].message.text
-            if stop is not None:
-                text = enforce_stop_tokens(text, stop)
-            return text
+
+def _make_request(
+    self: YandexGPT,
+    prompt: str,
+) -> str:
+    try:
+        import grpc
+        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
+            CompletionOptions,
+            Message,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
+            CompletionRequest,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
+            TextGenerationServiceStub,
+        )
+    except ImportError as e:
+        raise ImportError(
+            "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+        ) from e
+    channel_credentials = grpc.ssl_channel_credentials()
+    channel = grpc.secure_channel(self.url, channel_credentials)
+    request = CompletionRequest(
+        model_uri=self.model_uri,
+        completion_options=CompletionOptions(
+            temperature=DoubleValue(value=self.temperature),
+            max_tokens=Int64Value(value=self.max_tokens),
+        ),
+        messages=[Message(role="user", text=prompt)],
+    )
+    stub = TextGenerationServiceStub(channel)
+    res = stub.Completion(request, metadata=self._grpc_metadata)
+    return list(res)[0].alternatives[0].message.text
+
+
+async def _amake_request(self: YandexGPT, prompt: str) -> str:
+    try:
+        import asyncio
+
+        import grpc
+        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
+            CompletionOptions,
+            Message,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
+            CompletionRequest,
+            CompletionResponse,
+        )
+        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
+            TextGenerationAsyncServiceStub,
+        )
+        from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
+        from yandex.cloud.operation.operation_service_pb2_grpc import (
+            OperationServiceStub,
+        )
+    except ImportError as e:
+        raise ImportError(
+            "Please install YandexCloud SDK" " with `pip install yandexcloud`."
+        ) from e
+    operation_api_url = "operation.api.cloud.yandex.net:443"
+    channel_credentials = grpc.ssl_channel_credentials()
+    async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
+        request = CompletionRequest(
+            model_uri=self.model_uri,
+            completion_options=CompletionOptions(
+                temperature=DoubleValue(value=self.temperature),
+                max_tokens=Int64Value(value=self.max_tokens),
+            ),
+            messages=[Message(role="user", text=prompt)],
+        )
+        stub = TextGenerationAsyncServiceStub(channel)
+        operation = await stub.Completion(request, metadata=self._grpc_metadata)
+        async with grpc.aio.secure_channel(
+            operation_api_url, channel_credentials
+        ) as operation_channel:
+            operation_stub = OperationServiceStub(operation_channel)
+            while not operation.done:
+                await asyncio.sleep(1)
+                operation_request = GetOperationRequest(operation_id=operation.id)
+                operation = await operation_stub.Get(
+                    operation_request, metadata=self._grpc_metadata
+                )
+
+        completion_response = CompletionResponse()
+        operation.response.Unpack(completion_response)
+        return completion_response.alternatives[0].message.text
+
+
+def _create_retry_decorator(llm: YandexGPT) -> Callable[[Any], Any]:
+    from grpc import RpcError
+
+    min_seconds = 1
+    max_seconds = 60
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(llm.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(retry_if_exception_type((RpcError))),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def completion_with_retry(llm: YandexGPT, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)

+    @retry_decorator
+    def _completion_with_retry(**_kwargs: Any) -> Any:
+        return _make_request(llm, **_kwargs)
+
+    return _completion_with_retry(**kwargs)
+
+
+async def acompletion_with_retry(llm: YandexGPT, **kwargs: Any) -> Any:
+    """Use tenacity to retry the async completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    async def _completion_with_retry(**_kwargs: Any) -> Any:
+        return await _amake_request(llm, **_kwargs)
+
+    return await _completion_with_retry(**kwargs)
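
Usage note (not part of the diff above): a minimal sketch of the retry behavior
introduced by this patch, assuming the `yandexcloud` package is installed and
valid Yandex Cloud credentials are available. The new `max_retries` field
(default 6) feeds `_create_retry_decorator`, which retries a failed gRPC call
only on `grpc.RpcError`, waiting with exponential backoff between 1 and 60
seconds, and re-raises the last error once the attempts are exhausted.

    from langchain_community.llms import YandexGPT

    # max_retries is the field added by this patch; transient RpcError
    # failures are retried with exponential backoff, while other
    # exceptions propagate immediately.
    llm = YandexGPT(iam_token="t1.9eu...", folder_id="b1g...", max_retries=3)
    print(llm.invoke("Tell me a joke."))

The same `max_retries` field applies to `ChatYandexGPT`, since both models
inherit it from `_BaseYandexGPT`.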