mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-10 23:41:28 +00:00
community[patch]: Update YandexGPT API (#14773)
Update LLMand Chat model to use new api version --------- Co-authored-by: Dmitry Tyumentsev <dmitry.tyumentsev@raftds.com>
This commit is contained in:
committed by
GitHub
parent
eca89f87d8
commit
dcead816df
@@ -14,13 +14,19 @@ from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
class _BaseYandexGPT(Serializable):
|
||||
iam_token: str = ""
|
||||
"""Yandex Cloud IAM token for service account
|
||||
"""Yandex Cloud IAM token for service or user account
|
||||
with the `ai.languageModels.user` role"""
|
||||
api_key: str = ""
|
||||
"""Yandex Cloud Api Key for service account
|
||||
with the `ai.languageModels.user` role"""
|
||||
model_name: str = "general"
|
||||
folder_id: str = ""
|
||||
"""Yandex Cloud folder ID"""
|
||||
model_uri: str = ""
|
||||
"""Model uri to use."""
|
||||
model_name: str = "yandexgpt-lite"
|
||||
"""Model name to use."""
|
||||
model_version: str = "latest"
|
||||
"""Model version to use."""
|
||||
temperature: float = 0.6
|
||||
"""What sampling temperature to use.
|
||||
Should be a double number between 0 (inclusive) and 1 (inclusive)."""
|
||||
@@ -45,8 +51,27 @@ class _BaseYandexGPT(Serializable):
|
||||
values["iam_token"] = iam_token
|
||||
api_key = get_from_dict_or_env(values, "api_key", "YC_API_KEY", "")
|
||||
values["api_key"] = api_key
|
||||
folder_id = get_from_dict_or_env(values, "folder_id", "YC_FOLDER_ID", "")
|
||||
values["folder_id"] = folder_id
|
||||
if api_key == "" and iam_token == "":
|
||||
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
|
||||
|
||||
if values["iam_token"]:
|
||||
values["_grpc_metadata"] = [
|
||||
("authorization", f"Bearer {values['iam_token']}")
|
||||
]
|
||||
if values["folder_id"]:
|
||||
values["_grpc_metadata"].append(("x-folder-id", values["folder_id"]))
|
||||
else:
|
||||
values["_grpc_metadata"] = (
|
||||
("authorization", f"Api-Key {values['api_key']}"),
|
||||
)
|
||||
if values["model_uri"] == "" and values["folder_id"] == "":
|
||||
raise ValueError("Either 'model_uri' or 'folder_id' must be provided.")
|
||||
if not values["model_uri"]:
|
||||
values[
|
||||
"model_uri"
|
||||
] = f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
|
||||
return values
|
||||
|
||||
|
||||
@@ -62,18 +87,23 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
||||
- You can specify the key in a constructor parameter `api_key`
|
||||
or in an environment variable `YC_API_KEY`.
|
||||
|
||||
To use the default model specify the folder ID in a parameter `folder_id`
|
||||
or in an environment variable `YC_FOLDER_ID`.
|
||||
|
||||
Or specify the model URI in a constructor parameter `model_uri`
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import YandexGPT
|
||||
yandex_gpt = YandexGPT(iam_token="t1.9eu...")
|
||||
yandex_gpt = YandexGPT(iam_token="t1.9eu...", folder_id="b1g...")
|
||||
"""
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"model_uri": self.model_uri,
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stop": self.stop,
|
||||
@@ -103,9 +133,14 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
||||
try:
|
||||
import grpc
|
||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import InstructRequest
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
|
||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
|
||||
CompletionOptions,
|
||||
Message,
|
||||
)
|
||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
|
||||
CompletionRequest,
|
||||
)
|
||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
|
||||
TextGenerationServiceStub,
|
||||
)
|
||||
except ImportError as e:
|
||||
@@ -114,21 +149,21 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
||||
) from e
|
||||
channel_credentials = grpc.ssl_channel_credentials()
|
||||
channel = grpc.secure_channel(self.url, channel_credentials)
|
||||
request = InstructRequest(
|
||||
model=self.model_name,
|
||||
request_text=prompt,
|
||||
generation_options=GenerationOptions(
|
||||
request = CompletionRequest(
|
||||
model_uri=self.model_uri,
|
||||
completion_options=CompletionOptions(
|
||||
temperature=DoubleValue(value=self.temperature),
|
||||
max_tokens=Int64Value(value=self.max_tokens),
|
||||
),
|
||||
messages=[Message(role="user", text=prompt)],
|
||||
)
|
||||
stub = TextGenerationServiceStub(channel)
|
||||
if self.iam_token:
|
||||
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
||||
else:
|
||||
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
||||
res = stub.Instruct(request, metadata=metadata)
|
||||
text = list(res)[0].alternatives[0].text
|
||||
res = stub.Completion(request, metadata=metadata)
|
||||
text = list(res)[0].alternatives[0].message.text
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
@@ -154,12 +189,15 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
||||
|
||||
import grpc
|
||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import (
|
||||
InstructRequest,
|
||||
InstructResponse,
|
||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
|
||||
CompletionOptions,
|
||||
Message,
|
||||
)
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
|
||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
)
|
||||
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
|
||||
TextGenerationAsyncServiceStub,
|
||||
)
|
||||
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
|
||||
@@ -173,20 +211,16 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
||||
operation_api_url = "operation.api.cloud.yandex.net:443"
|
||||
channel_credentials = grpc.ssl_channel_credentials()
|
||||
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
|
||||
request = InstructRequest(
|
||||
model=self.model_name,
|
||||
request_text=prompt,
|
||||
generation_options=GenerationOptions(
|
||||
request = CompletionRequest(
|
||||
model_uri=self.model_uri,
|
||||
completion_options=CompletionOptions(
|
||||
temperature=DoubleValue(value=self.temperature),
|
||||
max_tokens=Int64Value(value=self.max_tokens),
|
||||
),
|
||||
messages=[Message(role="user", text=prompt)],
|
||||
)
|
||||
stub = TextGenerationAsyncServiceStub(channel)
|
||||
if self.iam_token:
|
||||
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
||||
else:
|
||||
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
||||
operation = await stub.Instruct(request, metadata=metadata)
|
||||
operation = await stub.Completion(request, metadata=self._grpc_metadata)
|
||||
async with grpc.aio.secure_channel(
|
||||
operation_api_url, channel_credentials
|
||||
) as operation_channel:
|
||||
@@ -195,12 +229,12 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
||||
await asyncio.sleep(1)
|
||||
operation_request = GetOperationRequest(operation_id=operation.id)
|
||||
operation = await operation_stub.Get(
|
||||
operation_request, metadata=metadata
|
||||
operation_request, metadata=self._grpc_metadata
|
||||
)
|
||||
|
||||
instruct_response = InstructResponse()
|
||||
instruct_response = CompletionResponse()
|
||||
operation.response.Unpack(instruct_response)
|
||||
text = instruct_response.alternatives[0].text
|
||||
text = instruct_response.alternatives[0].message.text
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
|
Reference in New Issue
Block a user