diff --git a/libs/langchain/langchain/llms/yandex.py b/libs/langchain/langchain/llms/yandex.py index ba1581b1883..58a2c831685 100644 --- a/libs/langchain/langchain/llms/yandex.py +++ b/libs/langchain/langchain/llms/yandex.py @@ -1,6 +1,9 @@ from typing import Any, Dict, List, Mapping, Optional -from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.load.serializable import Serializable @@ -128,3 +131,75 @@ class YandexGPT(_BaseYandexGPT, LLM): if stop is not None: text = enforce_stop_tokens(text, stop) return text + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Async call the Yandex GPT model and return the output. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + """ + try: + import asyncio + + import grpc + from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value + from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions + from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import ( + InstructRequest, + InstructResponse, + ) + from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import ( + TextGenerationAsyncServiceStub, + ) + from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest + from yandex.cloud.operation.operation_service_pb2_grpc import ( + OperationServiceStub, + ) + except ImportError as e: + raise ImportError( + "Please install YandexCloud SDK" " with `pip install yandexcloud`." + ) from e + operation_api_url = "operation.api.cloud.yandex.net:443" + channel_credentials = grpc.ssl_channel_credentials() + async with grpc.aio.secure_channel(self.url, channel_credentials) as channel: + request = InstructRequest( + model=self.model_name, + request_text=prompt, + generation_options=GenerationOptions( + temperature=DoubleValue(value=self.temperature), + max_tokens=Int64Value(value=self.max_tokens), + ), + ) + stub = TextGenerationAsyncServiceStub(channel) + if self.iam_token: + metadata = (("authorization", f"Bearer {self.iam_token}"),) + else: + metadata = (("authorization", f"Api-Key {self.api_key}"),) + operation = await stub.Instruct(request, metadata=metadata) + async with grpc.aio.secure_channel( + operation_api_url, channel_credentials + ) as operation_channel: + operation_stub = OperationServiceStub(operation_channel) + while not operation.done: + await asyncio.sleep(1) + operation_request = GetOperationRequest(operation_id=operation.id) + operation = await operation_stub.Get( + operation_request, metadata=metadata + ) + + instruct_response = InstructResponse() + operation.response.Unpack(instruct_response) + text = instruct_response.alternatives[0].text + if stop is not None: + text = enforce_stop_tokens(text, stop) + return text