mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
community[patch]: Add retry logic to Yandex GPT API Calls (#14907)
**Description:** Added retry logic that re-calls the YandexGPT API when a call fails with an error. Co-authored-by: Dmitry Tyumentsev <dmitry.tyumentsev@raftds.com>
This commit is contained in:
parent
425e5e1791
commit
50381abc42
@ -1,6 +1,8 @@
|
|||||||
"""Wrapper around YandexGPT chat models."""
|
"""Wrapper around YandexGPT chat models."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, cast
|
from typing import Any, Callable, Dict, List, Optional, cast
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -14,6 +16,13 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_community.llms.utils import enforce_stop_tokens
|
from langchain_community.llms.utils import enforce_stop_tokens
|
||||||
from langchain_community.llms.yandex import _BaseYandexGPT
|
from langchain_community.llms.yandex import _BaseYandexGPT
|
||||||
@ -80,41 +89,7 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if the last message in the list is not from human.
|
ValueError: if the last message in the list is not from human.
|
||||||
"""
|
"""
|
||||||
try:
|
text = completion_with_retry(self, messages=messages)
|
||||||
import grpc
|
|
||||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
|
|
||||||
CompletionOptions,
|
|
||||||
Message,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
|
|
||||||
CompletionRequest,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
|
|
||||||
TextGenerationServiceStub,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
|
||||||
) from e
|
|
||||||
if not messages:
|
|
||||||
raise ValueError(
|
|
||||||
"You should provide at least one message to start the chat!"
|
|
||||||
)
|
|
||||||
message_history = _parse_chat_history(messages)
|
|
||||||
channel_credentials = grpc.ssl_channel_credentials()
|
|
||||||
channel = grpc.secure_channel(self.url, channel_credentials)
|
|
||||||
request = CompletionRequest(
|
|
||||||
model_uri=self.model_uri,
|
|
||||||
completion_options=CompletionOptions(
|
|
||||||
temperature=DoubleValue(value=self.temperature),
|
|
||||||
max_tokens=Int64Value(value=self.max_tokens),
|
|
||||||
),
|
|
||||||
messages=[Message(**message) for message in message_history],
|
|
||||||
)
|
|
||||||
stub = TextGenerationServiceStub(channel)
|
|
||||||
res = stub.Completion(request, metadata=self._grpc_metadata)
|
|
||||||
text = list(res)[0].alternatives[0].message.text
|
|
||||||
text = text if stop is None else enforce_stop_tokens(text, stop)
|
text = text if stop is None else enforce_stop_tokens(text, stop)
|
||||||
message = AIMessage(content=text)
|
message = AIMessage(content=text)
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
@ -139,62 +114,139 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if the last message in the list is not from human.
|
ValueError: if the last message in the list is not from human.
|
||||||
"""
|
"""
|
||||||
try:
|
text = await acompletion_with_retry(self, messages=messages)
|
||||||
import asyncio
|
text = text if stop is None else enforce_stop_tokens(text, stop)
|
||||||
|
message = AIMessage(content=text)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
|
||||||
import grpc
|
|
||||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
|
|
||||||
CompletionOptions,
|
|
||||||
Message,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
|
|
||||||
CompletionRequest,
|
|
||||||
CompletionResponse,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
|
|
||||||
TextGenerationAsyncServiceStub,
|
|
||||||
)
|
|
||||||
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
|
|
||||||
from yandex.cloud.operation.operation_service_pb2_grpc import (
|
|
||||||
OperationServiceStub,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
|
||||||
) from e
|
|
||||||
if not messages:
|
|
||||||
raise ValueError(
|
|
||||||
"You should provide at least one message to start the chat!"
|
|
||||||
)
|
|
||||||
message_history = _parse_chat_history(messages)
|
|
||||||
operation_api_url = "operation.api.cloud.yandex.net:443"
|
|
||||||
channel_credentials = grpc.ssl_channel_credentials()
|
|
||||||
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
|
|
||||||
request = CompletionRequest(
|
|
||||||
model_uri=self.model_uri,
|
|
||||||
completion_options=CompletionOptions(
|
|
||||||
temperature=DoubleValue(value=self.temperature),
|
|
||||||
max_tokens=Int64Value(value=self.max_tokens),
|
|
||||||
),
|
|
||||||
messages=[Message(**message) for message in message_history],
|
|
||||||
)
|
|
||||||
stub = TextGenerationAsyncServiceStub(channel)
|
|
||||||
operation = await stub.Completion(request, metadata=self._grpc_metadata)
|
|
||||||
async with grpc.aio.secure_channel(
|
|
||||||
operation_api_url, channel_credentials
|
|
||||||
) as operation_channel:
|
|
||||||
operation_stub = OperationServiceStub(operation_channel)
|
|
||||||
while not operation.done:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
operation_request = GetOperationRequest(operation_id=operation.id)
|
|
||||||
operation = await operation_stub.Get(
|
|
||||||
operation_request, metadata=self._grpc_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_response = CompletionResponse()
|
def _make_request(
    self: ChatYandexGPT,
    messages: List[BaseMessage],
) -> str:
    """Send one blocking completion request to YandexGPT and return its text.

    Args:
        self: Chat model supplying ``url``, ``model_uri``, ``temperature``,
            ``max_tokens`` and the gRPC call metadata (``_grpc_metadata``).
        messages: Conversation to send; must contain at least one message.

    Returns:
        The message text of the first alternative in the first item yielded
        by the ``Completion`` call.

    Raises:
        ImportError: If the YandexCloud SDK (or the gRPC/protobuf modules it
            relies on) cannot be imported.
        ValueError: If ``messages`` is empty.
    """
    try:
        import grpc
        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
            CompletionOptions,
            Message,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
            CompletionRequest,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
            TextGenerationServiceStub,
        )
    except ImportError as e:
        raise ImportError(
            "Please install YandexCloud SDK" " with `pip install yandexcloud`."
        ) from e
    if not messages:
        raise ValueError("You should provide at least one message to start the chat!")
    message_history = _parse_chat_history(messages)
    channel_credentials = grpc.ssl_channel_credentials()
    request = CompletionRequest(
        model_uri=self.model_uri,
        completion_options=CompletionOptions(
            temperature=DoubleValue(value=self.temperature),
            max_tokens=Int64Value(value=self.max_tokens),
        ),
        messages=[Message(**message) for message in message_history],
    )
    # Use the channel as a context manager so it is closed on every exit
    # path. This function is re-invoked on each retry attempt, so an unclosed
    # channel would otherwise be left behind per attempt.
    with grpc.secure_channel(self.url, channel_credentials) as channel:
        stub = TextGenerationServiceStub(channel)
        res = stub.Completion(request, metadata=self._grpc_metadata)
        # Drain the response iterable while the channel is still open.
        return list(res)[0].alternatives[0].message.text
|
||||||
|
|
||||||
|
|
||||||
|
async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str:
    """Asynchronously request a chat completion and wait for its result.

    Submits the parsed chat history through the async text-generation stub,
    then polls the operation service (sleeping one second between polls)
    until the returned operation reports ``done``, and finally unpacks the
    operation response into a ``CompletionResponse``.

    Args:
        self: Chat model supplying ``url``, ``model_uri``, ``temperature``,
            ``max_tokens`` and the gRPC call metadata (``_grpc_metadata``).
        messages: Conversation to send; must contain at least one message.

    Returns:
        The message text of the first alternative in the unpacked response.

    Raises:
        ImportError: If the YandexCloud SDK imports fail.
        ValueError: If ``messages`` is empty.
    """
    try:
        import asyncio

        import grpc
        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
            CompletionOptions,
            Message,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
            CompletionRequest,
            CompletionResponse,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
            TextGenerationAsyncServiceStub,
        )
        from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
        from yandex.cloud.operation.operation_service_pb2_grpc import (
            OperationServiceStub,
        )
    except ImportError as e:
        raise ImportError(
            "Please install YandexCloud SDK with `pip install yandexcloud`."
        ) from e

    if not messages:
        raise ValueError("You should provide at least one message to start the chat!")

    history = _parse_chat_history(messages)
    credentials = grpc.ssl_channel_credentials()
    async with grpc.aio.secure_channel(self.url, credentials) as llm_channel:
        options = CompletionOptions(
            temperature=DoubleValue(value=self.temperature),
            max_tokens=Int64Value(value=self.max_tokens),
        )
        payload = CompletionRequest(
            model_uri=self.model_uri,
            completion_options=options,
            messages=[Message(**entry) for entry in history],
        )
        llm_stub = TextGenerationAsyncServiceStub(llm_channel)
        op = await llm_stub.Completion(payload, metadata=self._grpc_metadata)

        # The completion call hands back an operation handle; its state is
        # refreshed through a second channel to the operation service.
        async with grpc.aio.secure_channel(
            "operation.api.cloud.yandex.net:443", credentials
        ) as op_channel:
            poller = OperationServiceStub(op_channel)
            while not op.done:
                await asyncio.sleep(1)
                op = await poller.Get(
                    GetOperationRequest(operation_id=op.id),
                    metadata=self._grpc_metadata,
                )

        # NOTE(review): nothing here inspects an operation-level error before
        # unpacking; presumably a failed operation yields no alternatives and
        # the index below raises IndexError -- confirm against the API.
        parsed = CompletionResponse()
        op.response.Unpack(parsed)
        return parsed.alternatives[0].message.text
|
||||||
|
|
||||||
|
|
||||||
|
def _create_retry_decorator(llm: ChatYandexGPT) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for YandexGPT chat calls.

    The returned decorator retries only when ``grpc.RpcError`` is raised,
    stops after ``llm.max_retries`` attempts, waits with exponential backoff
    bounded between 1 and 60 seconds, logs at WARNING level before each
    sleep, and is configured with ``reraise=True``.

    Args:
        llm: Model whose ``max_retries`` sets the attempt limit.

    Returns:
        A decorator to apply to the function performing the request.
    """
    from grpc import RpcError

    attempt_limit = stop_after_attempt(llm.max_retries)
    backoff = wait_exponential(multiplier=1, min=1, max=60)
    only_rpc_errors = retry_if_exception_type(RpcError)
    warn_before_sleep = before_sleep_log(logger, logging.WARNING)
    return retry(
        reraise=True,
        stop=attempt_limit,
        wait=backoff,
        retry=only_rpc_errors,
        before_sleep=warn_before_sleep,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def completion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
    """Run the blocking chat completion request under the retry policy.

    Args:
        llm: Model used both to configure the retry decorator and to perform
            the request.
        **kwargs: Forwarded unchanged to ``_make_request``.

    Returns:
        Whatever ``_make_request`` returns.
    """

    # Keep this inner function's name: it is the callable handed to the
    # retry decorator, which is configured with a before-sleep logger.
    def _completion_with_retry(**_kwargs: Any) -> Any:
        return _make_request(llm, **_kwargs)

    guarded = _create_retry_decorator(llm)(_completion_with_retry)
    return guarded(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def acompletion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
    """Run the async chat completion request under the retry policy.

    Args:
        llm: Model used both to configure the retry decorator and to perform
            the request.
        **kwargs: Forwarded unchanged to ``_amake_request``.

    Returns:
        Whatever ``_amake_request`` returns.
    """

    # Must stay a coroutine function with this name: it is the callable that
    # gets wrapped by the retry decorator and awaited below.
    async def _completion_with_retry(**_kwargs: Any) -> Any:
        return await _amake_request(llm, **_kwargs)

    guarded = _create_retry_decorator(llm)(_completion_with_retry)
    return await guarded(**kwargs)
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
from typing import Any, Dict, List, Mapping, Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -8,9 +11,18 @@ from langchain_core.language_models.llms import LLM
|
|||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
from langchain_core.pydantic_v1 import root_validator
|
from langchain_core.pydantic_v1 import root_validator
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_community.llms.utils import enforce_stop_tokens
|
from langchain_community.llms.utils import enforce_stop_tokens
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _BaseYandexGPT(Serializable):
|
class _BaseYandexGPT(Serializable):
|
||||||
iam_token: str = ""
|
iam_token: str = ""
|
||||||
@ -38,11 +50,24 @@ class _BaseYandexGPT(Serializable):
|
|||||||
"""Sequences when completion generation will stop."""
|
"""Sequences when completion generation will stop."""
|
||||||
url: str = "llm.api.cloud.yandex.net:443"
|
url: str = "llm.api.cloud.yandex.net:443"
|
||||||
"""The url of the API."""
|
"""The url of the API."""
|
||||||
|
max_retries: int = 6
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "yandex_gpt"
|
return "yandex_gpt"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {
|
||||||
|
"model_uri": self.model_uri,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"stop": self.stop,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
|
}
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that iam token exists in environment."""
|
"""Validate that iam token exists in environment."""
|
||||||
@ -99,16 +124,6 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
|||||||
yandex_gpt = YandexGPT(iam_token="t1.9eu...", folder_id="b1g...")
|
yandex_gpt = YandexGPT(iam_token="t1.9eu...", folder_id="b1g...")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
|
||||||
"""Get the identifying parameters."""
|
|
||||||
return {
|
|
||||||
"model_uri": self.model_uri,
|
|
||||||
"temperature": self.temperature,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
"stop": self.stop,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -130,40 +145,7 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
|||||||
|
|
||||||
response = YandexGPT("Tell me a joke.")
|
response = YandexGPT("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
try:
|
text = completion_with_retry(self, prompt=prompt)
|
||||||
import grpc
|
|
||||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
|
|
||||||
CompletionOptions,
|
|
||||||
Message,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
|
|
||||||
CompletionRequest,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
|
|
||||||
TextGenerationServiceStub,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
|
||||||
) from e
|
|
||||||
channel_credentials = grpc.ssl_channel_credentials()
|
|
||||||
channel = grpc.secure_channel(self.url, channel_credentials)
|
|
||||||
request = CompletionRequest(
|
|
||||||
model_uri=self.model_uri,
|
|
||||||
completion_options=CompletionOptions(
|
|
||||||
temperature=DoubleValue(value=self.temperature),
|
|
||||||
max_tokens=Int64Value(value=self.max_tokens),
|
|
||||||
),
|
|
||||||
messages=[Message(role="user", text=prompt)],
|
|
||||||
)
|
|
||||||
stub = TextGenerationServiceStub(channel)
|
|
||||||
if self.iam_token:
|
|
||||||
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
|
||||||
else:
|
|
||||||
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
|
||||||
res = stub.Completion(request, metadata=metadata)
|
|
||||||
text = list(res)[0].alternatives[0].message.text
|
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
text = enforce_stop_tokens(text, stop)
|
text = enforce_stop_tokens(text, stop)
|
||||||
return text
|
return text
|
||||||
@ -184,57 +166,133 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
|||||||
Returns:
|
Returns:
|
||||||
The string generated by the model.
|
The string generated by the model.
|
||||||
"""
|
"""
|
||||||
try:
|
text = await acompletion_with_retry(self, prompt=prompt)
|
||||||
import asyncio
|
if stop is not None:
|
||||||
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
return text
|
||||||
|
|
||||||
import grpc
|
|
||||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
|
|
||||||
CompletionOptions,
|
|
||||||
Message,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
|
|
||||||
CompletionRequest,
|
|
||||||
CompletionResponse,
|
|
||||||
)
|
|
||||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
|
|
||||||
TextGenerationAsyncServiceStub,
|
|
||||||
)
|
|
||||||
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
|
|
||||||
from yandex.cloud.operation.operation_service_pb2_grpc import (
|
|
||||||
OperationServiceStub,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
|
||||||
) from e
|
|
||||||
operation_api_url = "operation.api.cloud.yandex.net:443"
|
|
||||||
channel_credentials = grpc.ssl_channel_credentials()
|
|
||||||
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
|
|
||||||
request = CompletionRequest(
|
|
||||||
model_uri=self.model_uri,
|
|
||||||
completion_options=CompletionOptions(
|
|
||||||
temperature=DoubleValue(value=self.temperature),
|
|
||||||
max_tokens=Int64Value(value=self.max_tokens),
|
|
||||||
),
|
|
||||||
messages=[Message(role="user", text=prompt)],
|
|
||||||
)
|
|
||||||
stub = TextGenerationAsyncServiceStub(channel)
|
|
||||||
operation = await stub.Completion(request, metadata=self._grpc_metadata)
|
|
||||||
async with grpc.aio.secure_channel(
|
|
||||||
operation_api_url, channel_credentials
|
|
||||||
) as operation_channel:
|
|
||||||
operation_stub = OperationServiceStub(operation_channel)
|
|
||||||
while not operation.done:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
operation_request = GetOperationRequest(operation_id=operation.id)
|
|
||||||
operation = await operation_stub.Get(
|
|
||||||
operation_request, metadata=self._grpc_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_response = CompletionResponse()
|
def _make_request(
    self: YandexGPT,
    prompt: str,
) -> str:
    """Send one blocking completion request for ``prompt`` and return its text.

    The prompt is sent as a single message with role ``"user"``.

    Args:
        self: LLM supplying ``url``, ``model_uri``, ``temperature``,
            ``max_tokens`` and the gRPC call metadata (``_grpc_metadata``).
        prompt: Text to send to the model.

    Returns:
        The message text of the first alternative in the first item yielded
        by the ``Completion`` call.

    Raises:
        ImportError: If the YandexCloud SDK (or the gRPC/protobuf modules it
            relies on) cannot be imported.
    """
    try:
        import grpc
        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
            CompletionOptions,
            Message,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
            CompletionRequest,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
            TextGenerationServiceStub,
        )
    except ImportError as e:
        raise ImportError(
            "Please install YandexCloud SDK" " with `pip install yandexcloud`."
        ) from e
    channel_credentials = grpc.ssl_channel_credentials()
    request = CompletionRequest(
        model_uri=self.model_uri,
        completion_options=CompletionOptions(
            temperature=DoubleValue(value=self.temperature),
            max_tokens=Int64Value(value=self.max_tokens),
        ),
        messages=[Message(role="user", text=prompt)],
    )
    # Use the channel as a context manager so it is closed on every exit
    # path. This function is re-invoked on each retry attempt, so an unclosed
    # channel would otherwise be left behind per attempt.
    with grpc.secure_channel(self.url, channel_credentials) as channel:
        stub = TextGenerationServiceStub(channel)
        res = stub.Completion(request, metadata=self._grpc_metadata)
        # Drain the response iterable while the channel is still open.
        return list(res)[0].alternatives[0].message.text
|
||||||
|
|
||||||
|
|
||||||
|
async def _amake_request(self: YandexGPT, prompt: str) -> str:
    """Asynchronously request a completion for ``prompt`` and wait for it.

    Sends the prompt as a single ``"user"`` message through the async
    text-generation stub, polls the operation service (sleeping one second
    between polls) until the operation reports ``done``, then unpacks the
    operation response into a ``CompletionResponse``.

    Args:
        self: LLM supplying ``url``, ``model_uri``, ``temperature``,
            ``max_tokens`` and the gRPC call metadata (``_grpc_metadata``).
        prompt: Text to send to the model.

    Returns:
        The message text of the first alternative in the unpacked response.

    Raises:
        ImportError: If the YandexCloud SDK imports fail.
    """
    try:
        import asyncio

        import grpc
        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
        from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
            CompletionOptions,
            Message,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
            CompletionRequest,
            CompletionResponse,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
            TextGenerationAsyncServiceStub,
        )
        from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
        from yandex.cloud.operation.operation_service_pb2_grpc import (
            OperationServiceStub,
        )
    except ImportError as e:
        raise ImportError(
            "Please install YandexCloud SDK with `pip install yandexcloud`."
        ) from e

    operations_endpoint = "operation.api.cloud.yandex.net:443"
    credentials = grpc.ssl_channel_credentials()
    async with grpc.aio.secure_channel(self.url, credentials) as llm_channel:
        user_message = Message(role="user", text=prompt)
        payload = CompletionRequest(
            model_uri=self.model_uri,
            completion_options=CompletionOptions(
                temperature=DoubleValue(value=self.temperature),
                max_tokens=Int64Value(value=self.max_tokens),
            ),
            messages=[user_message],
        )
        llm_stub = TextGenerationAsyncServiceStub(llm_channel)
        op = await llm_stub.Completion(payload, metadata=self._grpc_metadata)

        # The completion call hands back an operation handle; its state is
        # refreshed through a second channel to the operation service.
        async with grpc.aio.secure_channel(
            operations_endpoint, credentials
        ) as op_channel:
            poller = OperationServiceStub(op_channel)
            while not op.done:
                await asyncio.sleep(1)
                status_request = GetOperationRequest(operation_id=op.id)
                op = await poller.Get(status_request, metadata=self._grpc_metadata)

        # NOTE(review): nothing here inspects an operation-level error before
        # unpacking; presumably a failed operation yields no alternatives and
        # the index below raises IndexError -- confirm against the API.
        parsed = CompletionResponse()
        op.response.Unpack(parsed)
        return parsed.alternatives[0].message.text
|
||||||
|
|
||||||
|
|
||||||
|
def _create_retry_decorator(llm: YandexGPT) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for YandexGPT LLM calls.

    The returned decorator retries only when ``grpc.RpcError`` is raised,
    stops after ``llm.max_retries`` attempts, waits with exponential backoff
    bounded between 1 and 60 seconds, logs at WARNING level before each
    sleep, and is configured with ``reraise=True``.

    Args:
        llm: Model whose ``max_retries`` sets the attempt limit.

    Returns:
        A decorator to apply to the function performing the request.
    """
    from grpc import RpcError

    policy = {
        "reraise": True,
        "stop": stop_after_attempt(llm.max_retries),
        "wait": wait_exponential(multiplier=1, min=1, max=60),
        "retry": retry_if_exception_type(RpcError),
        "before_sleep": before_sleep_log(logger, logging.WARNING),
    }
    return retry(**policy)
|
||||||
|
|
||||||
|
|
||||||
|
def completion_with_retry(llm: YandexGPT, **kwargs: Any) -> Any:
    """Run the blocking completion request under the retry policy.

    Args:
        llm: Model used both to configure the retry decorator and to perform
            the request.
        **kwargs: Forwarded unchanged to ``_make_request``.

    Returns:
        Whatever ``_make_request`` returns.
    """

    # Keep this inner function's name: it is the callable handed to the
    # retry decorator, which is configured with a before-sleep logger.
    def _completion_with_retry(**_kwargs: Any) -> Any:
        return _make_request(llm, **_kwargs)

    with_retries = _create_retry_decorator(llm)
    return with_retries(_completion_with_retry)(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def acompletion_with_retry(llm: YandexGPT, **kwargs: Any) -> Any:
    """Run the async completion request under the retry policy.

    Args:
        llm: Model used both to configure the retry decorator and to perform
            the request.
        **kwargs: Forwarded unchanged to ``_amake_request``.

    Returns:
        Whatever ``_amake_request`` returns.
    """

    # Must stay a coroutine function with this name: it is the callable that
    # gets wrapped by the retry decorator and awaited below.
    async def _completion_with_retry(**_kwargs: Any) -> Any:
        return await _amake_request(llm, **_kwargs)

    with_retries = _create_retry_decorator(llm)
    return await with_retries(_completion_with_retry)(**kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user