community[patch]: Update YandexGPT API (#14773)

Update LLM and Chat model to use new API version

---------

Co-authored-by: Dmitry Tyumentsev <dmitry.tyumentsev@raftds.com>
This commit is contained in:
Dmitry Tyumentsev 2023-12-16 03:25:09 +03:00 committed by GitHub
parent eca89f87d8
commit dcead816df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 185 additions and 69 deletions

View File

@ -42,13 +42,20 @@
"Next, you have two authentication options:\n", "Next, you have two authentication options:\n",
"- [IAM token](https://cloud.yandex.com/en/docs/iam/operations/iam-token/create-for-sa).\n", "- [IAM token](https://cloud.yandex.com/en/docs/iam/operations/iam-token/create-for-sa).\n",
" You can specify the token in a constructor parameter `iam_token` or in an environment variable `YC_IAM_TOKEN`.\n", " You can specify the token in a constructor parameter `iam_token` or in an environment variable `YC_IAM_TOKEN`.\n",
"\n",
"- [API key](https://cloud.yandex.com/en/docs/iam/operations/api-key/create)\n", "- [API key](https://cloud.yandex.com/en/docs/iam/operations/api-key/create)\n",
" You can specify the key in a constructor parameter `api_key` or in an environment variable `YC_API_KEY`." " You can specify the key in a constructor parameter `api_key` or in an environment variable `YC_API_KEY`.\n",
"\n",
"In the `model_uri` parameter, specify the model used, see [the documentation](https://cloud.yandex.com/en/docs/yandexgpt/concepts/models#yandexgpt-generation) for more details.\n",
"\n",
"To specify the model you can use `model_uri` parameter, see [the documentation](https://cloud.yandex.com/en/docs/yandexgpt/concepts/models#yandexgpt-generation) for more details.\n",
"\n",
"By default, the latest version of `yandexgpt-lite` is used from the folder specified in the parameter `folder_id` or `YC_FOLDER_ID` environment variable."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 1,
"id": "eba2d63b-f871-4f61-b55f-f6092bdc297a", "id": "eba2d63b-f871-4f61-b55f-f6092bdc297a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -59,7 +66,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 2,
"id": "75905d9a-dfae-43aa-95b9-a160280e43f7", "id": "75905d9a-dfae-43aa-95b9-a160280e43f7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -69,17 +76,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 3,
"id": "40844fe7-7fe5-4679-b6c9-1b3238807bdc", "id": "40844fe7-7fe5-4679-b6c9-1b3238807bdc",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"AIMessage(content=\"Je t'aime programmer.\")" "AIMessage(content='Je adore le programmement.')"
] ]
}, },
"execution_count": 8, "execution_count": 3,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -113,7 +120,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.18" "version": "3.10.13"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -29,13 +29,20 @@
"Next, you have two authentication options:\n", "Next, you have two authentication options:\n",
"- [IAM token](https://cloud.yandex.com/en/docs/iam/operations/iam-token/create-for-sa).\n", "- [IAM token](https://cloud.yandex.com/en/docs/iam/operations/iam-token/create-for-sa).\n",
" You can specify the token in a constructor parameter `iam_token` or in an environment variable `YC_IAM_TOKEN`.\n", " You can specify the token in a constructor parameter `iam_token` or in an environment variable `YC_IAM_TOKEN`.\n",
"\n",
"- [API key](https://cloud.yandex.com/en/docs/iam/operations/api-key/create)\n", "- [API key](https://cloud.yandex.com/en/docs/iam/operations/api-key/create)\n",
" You can specify the key in a constructor parameter `api_key` or in an environment variable `YC_API_KEY`." " You can specify the key in a constructor parameter `api_key` or in an environment variable `YC_API_KEY`.\n",
"\n",
"In the `model_uri` parameter, specify the model used, see [the documentation](https://cloud.yandex.com/en/docs/yandexgpt/concepts/models#yandexgpt-generation) for more details.\n",
"\n",
"To specify the model you can use `model_uri` parameter, see [the documentation](https://cloud.yandex.com/en/docs/yandexgpt/concepts/models#yandexgpt-generation) for more details.\n",
"\n",
"By default, the latest version of `yandexgpt-lite` is used from the folder specified in the parameter `folder_id` or `YC_FOLDER_ID` environment variable."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 246, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -46,7 +53,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 247, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -56,7 +63,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 248, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -65,7 +72,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 249, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -74,16 +81,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 250, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'Moscow'" "'The capital of Russia is Moscow.'"
] ]
}, },
"execution_count": 250, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -111,7 +118,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.18" "version": "3.10.13"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,6 +1,6 @@
"""Wrapper around YandexGPT chat models.""" """Wrapper around YandexGPT chat models."""
import logging import logging
from typing import Any, Dict, List, Optional, Tuple, cast from typing import Any, Dict, List, Optional, cast
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -25,14 +25,13 @@ def _parse_message(role: str, text: str) -> Dict:
return {"role": role, "text": text} return {"role": role, "text": text}
def _parse_chat_history(history: List[BaseMessage]) -> Tuple[List[Dict[str, str]], str]: def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
"""Parse a sequence of messages into history. """Parse a sequence of messages into history.
Returns: Returns:
A tuple of a list of parsed messages and an instruction message for the model. A list of parsed messages.
""" """
chat_history = [] chat_history = []
instruction = ""
for message in history: for message in history:
content = cast(str, message.content) content = cast(str, message.content)
if isinstance(message, HumanMessage): if isinstance(message, HumanMessage):
@ -40,8 +39,8 @@ def _parse_chat_history(history: List[BaseMessage]) -> Tuple[List[Dict[str, str]
if isinstance(message, AIMessage): if isinstance(message, AIMessage):
chat_history.append(_parse_message("assistant", content)) chat_history.append(_parse_message("assistant", content))
if isinstance(message, SystemMessage): if isinstance(message, SystemMessage):
instruction = content chat_history.append(_parse_message("system", content))
return chat_history, instruction return chat_history
class ChatYandexGPT(_BaseYandexGPT, BaseChatModel): class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
@ -84,9 +83,14 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
try: try:
import grpc import grpc
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions, Message from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import ChatRequest CompletionOptions,
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import ( Message,
)
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
CompletionRequest,
)
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
TextGenerationServiceStub, TextGenerationServiceStub,
) )
except ImportError as e: except ImportError as e:
@ -97,25 +101,20 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
raise ValueError( raise ValueError(
"You should provide at least one message to start the chat!" "You should provide at least one message to start the chat!"
) )
message_history, instruction = _parse_chat_history(messages) message_history = _parse_chat_history(messages)
channel_credentials = grpc.ssl_channel_credentials() channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(self.url, channel_credentials) channel = grpc.secure_channel(self.url, channel_credentials)
request = ChatRequest( request = CompletionRequest(
model=self.model_name, model_uri=self.model_uri,
generation_options=GenerationOptions( completion_options=CompletionOptions(
temperature=DoubleValue(value=self.temperature), temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens), max_tokens=Int64Value(value=self.max_tokens),
), ),
instruction_text=instruction,
messages=[Message(**message) for message in message_history], messages=[Message(**message) for message in message_history],
) )
stub = TextGenerationServiceStub(channel) stub = TextGenerationServiceStub(channel)
if self.iam_token: res = stub.Completion(request, metadata=self._grpc_metadata)
metadata = (("authorization", f"Bearer {self.iam_token}"),) text = list(res)[0].alternatives[0].message.text
else:
metadata = (("authorization", f"Api-Key {self.api_key}"),)
res = stub.Chat(request, metadata=metadata)
text = list(res)[0].message.text
text = text if stop is None else enforce_stop_tokens(text, stop) text = text if stop is None else enforce_stop_tokens(text, stop)
message = AIMessage(content=text) message = AIMessage(content=text)
return ChatResult(generations=[ChatGeneration(message=message)]) return ChatResult(generations=[ChatGeneration(message=message)])
@ -127,6 +126,75 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
raise NotImplementedError( """Async method to generate next turn in the conversation.
"""YandexGPT doesn't support async requests at the moment."""
Args:
messages: The history of the conversation as a list of messages.
stop: The list of stop words (optional).
run_manager: The CallbackManager for LLM run, it's not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from human.
"""
try:
import asyncio
import grpc
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
CompletionOptions,
Message,
) )
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
CompletionRequest,
CompletionResponse,
)
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
TextGenerationAsyncServiceStub,
)
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
from yandex.cloud.operation.operation_service_pb2_grpc import (
OperationServiceStub,
)
except ImportError as e:
raise ImportError(
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
) from e
if not messages:
raise ValueError(
"You should provide at least one message to start the chat!"
)
message_history = _parse_chat_history(messages)
operation_api_url = "operation.api.cloud.yandex.net:443"
channel_credentials = grpc.ssl_channel_credentials()
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
request = CompletionRequest(
model_uri=self.model_uri,
completion_options=CompletionOptions(
temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens),
),
messages=[Message(**message) for message in message_history],
)
stub = TextGenerationAsyncServiceStub(channel)
operation = await stub.Completion(request, metadata=self._grpc_metadata)
async with grpc.aio.secure_channel(
operation_api_url, channel_credentials
) as operation_channel:
operation_stub = OperationServiceStub(operation_channel)
while not operation.done:
await asyncio.sleep(1)
operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get(
operation_request, metadata=self._grpc_metadata
)
instruct_response = CompletionResponse()
operation.response.Unpack(instruct_response)
text = instruct_response.alternatives[0].message.text
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text

View File

@ -14,13 +14,19 @@ from langchain_community.llms.utils import enforce_stop_tokens
class _BaseYandexGPT(Serializable): class _BaseYandexGPT(Serializable):
iam_token: str = "" iam_token: str = ""
"""Yandex Cloud IAM token for service account """Yandex Cloud IAM token for service or user account
with the `ai.languageModels.user` role""" with the `ai.languageModels.user` role"""
api_key: str = "" api_key: str = ""
"""Yandex Cloud Api Key for service account """Yandex Cloud Api Key for service account
with the `ai.languageModels.user` role""" with the `ai.languageModels.user` role"""
model_name: str = "general" folder_id: str = ""
"""Yandex Cloud folder ID"""
model_uri: str = ""
"""Model uri to use."""
model_name: str = "yandexgpt-lite"
"""Model name to use.""" """Model name to use."""
model_version: str = "latest"
"""Model version to use."""
temperature: float = 0.6 temperature: float = 0.6
"""What sampling temperature to use. """What sampling temperature to use.
Should be a double number between 0 (inclusive) and 1 (inclusive).""" Should be a double number between 0 (inclusive) and 1 (inclusive)."""
@ -45,8 +51,27 @@ class _BaseYandexGPT(Serializable):
values["iam_token"] = iam_token values["iam_token"] = iam_token
api_key = get_from_dict_or_env(values, "api_key", "YC_API_KEY", "") api_key = get_from_dict_or_env(values, "api_key", "YC_API_KEY", "")
values["api_key"] = api_key values["api_key"] = api_key
folder_id = get_from_dict_or_env(values, "folder_id", "YC_FOLDER_ID", "")
values["folder_id"] = folder_id
if api_key == "" and iam_token == "": if api_key == "" and iam_token == "":
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.") raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
if values["iam_token"]:
values["_grpc_metadata"] = [
("authorization", f"Bearer {values['iam_token']}")
]
if values["folder_id"]:
values["_grpc_metadata"].append(("x-folder-id", values["folder_id"]))
else:
values["_grpc_metadata"] = (
("authorization", f"Api-Key {values['api_key']}"),
)
if values["model_uri"] == "" and values["folder_id"] == "":
raise ValueError("Either 'model_uri' or 'folder_id' must be provided.")
if not values["model_uri"]:
values[
"model_uri"
] = f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
return values return values
@ -62,18 +87,23 @@ class YandexGPT(_BaseYandexGPT, LLM):
- You can specify the key in a constructor parameter `api_key` - You can specify the key in a constructor parameter `api_key`
or in an environment variable `YC_API_KEY`. or in an environment variable `YC_API_KEY`.
To use the default model specify the folder ID in a parameter `folder_id`
or in an environment variable `YC_FOLDER_ID`.
Or specify the model URI in a constructor parameter `model_uri`
Example: Example:
.. code-block:: python .. code-block:: python
from langchain_community.llms import YandexGPT from langchain_community.llms import YandexGPT
yandex_gpt = YandexGPT(iam_token="t1.9eu...") yandex_gpt = YandexGPT(iam_token="t1.9eu...", folder_id="b1g...")
""" """
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return { return {
"model_name": self.model_name, "model_uri": self.model_uri,
"temperature": self.temperature, "temperature": self.temperature,
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
"stop": self.stop, "stop": self.stop,
@ -103,9 +133,14 @@ class YandexGPT(_BaseYandexGPT, LLM):
try: try:
import grpc import grpc
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import InstructRequest CompletionOptions,
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import ( Message,
)
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
CompletionRequest,
)
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
TextGenerationServiceStub, TextGenerationServiceStub,
) )
except ImportError as e: except ImportError as e:
@ -114,21 +149,21 @@ class YandexGPT(_BaseYandexGPT, LLM):
) from e ) from e
channel_credentials = grpc.ssl_channel_credentials() channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(self.url, channel_credentials) channel = grpc.secure_channel(self.url, channel_credentials)
request = InstructRequest( request = CompletionRequest(
model=self.model_name, model_uri=self.model_uri,
request_text=prompt, completion_options=CompletionOptions(
generation_options=GenerationOptions(
temperature=DoubleValue(value=self.temperature), temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens), max_tokens=Int64Value(value=self.max_tokens),
), ),
messages=[Message(role="user", text=prompt)],
) )
stub = TextGenerationServiceStub(channel) stub = TextGenerationServiceStub(channel)
if self.iam_token: if self.iam_token:
metadata = (("authorization", f"Bearer {self.iam_token}"),) metadata = (("authorization", f"Bearer {self.iam_token}"),)
else: else:
metadata = (("authorization", f"Api-Key {self.api_key}"),) metadata = (("authorization", f"Api-Key {self.api_key}"),)
res = stub.Instruct(request, metadata=metadata) res = stub.Completion(request, metadata=metadata)
text = list(res)[0].alternatives[0].text text = list(res)[0].alternatives[0].message.text
if stop is not None: if stop is not None:
text = enforce_stop_tokens(text, stop) text = enforce_stop_tokens(text, stop)
return text return text
@ -154,12 +189,15 @@ class YandexGPT(_BaseYandexGPT, LLM):
import grpc import grpc
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions from yandex.cloud.ai.foundation_models.v1.foundation_models_pb2 import (
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import ( CompletionOptions,
InstructRequest, Message,
InstructResponse,
) )
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import ( from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
CompletionRequest,
CompletionResponse,
)
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import ( # noqa: E501
TextGenerationAsyncServiceStub, TextGenerationAsyncServiceStub,
) )
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
@ -173,20 +211,16 @@ class YandexGPT(_BaseYandexGPT, LLM):
operation_api_url = "operation.api.cloud.yandex.net:443" operation_api_url = "operation.api.cloud.yandex.net:443"
channel_credentials = grpc.ssl_channel_credentials() channel_credentials = grpc.ssl_channel_credentials()
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel: async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
request = InstructRequest( request = CompletionRequest(
model=self.model_name, model_uri=self.model_uri,
request_text=prompt, completion_options=CompletionOptions(
generation_options=GenerationOptions(
temperature=DoubleValue(value=self.temperature), temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens), max_tokens=Int64Value(value=self.max_tokens),
), ),
messages=[Message(role="user", text=prompt)],
) )
stub = TextGenerationAsyncServiceStub(channel) stub = TextGenerationAsyncServiceStub(channel)
if self.iam_token: operation = await stub.Completion(request, metadata=self._grpc_metadata)
metadata = (("authorization", f"Bearer {self.iam_token}"),)
else:
metadata = (("authorization", f"Api-Key {self.api_key}"),)
operation = await stub.Instruct(request, metadata=metadata)
async with grpc.aio.secure_channel( async with grpc.aio.secure_channel(
operation_api_url, channel_credentials operation_api_url, channel_credentials
) as operation_channel: ) as operation_channel:
@ -195,12 +229,12 @@ class YandexGPT(_BaseYandexGPT, LLM):
await asyncio.sleep(1) await asyncio.sleep(1)
operation_request = GetOperationRequest(operation_id=operation.id) operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get( operation = await operation_stub.Get(
operation_request, metadata=metadata operation_request, metadata=self._grpc_metadata
) )
instruct_response = InstructResponse() instruct_response = CompletionResponse()
operation.response.Unpack(instruct_response) operation.response.Unpack(instruct_response)
text = instruct_response.alternatives[0].text text = instruct_response.alternatives[0].message.text
if stop is not None: if stop is not None:
text = enforce_stop_tokens(text, stop) text = enforce_stop_tokens(text, stop)
return text return text