community[patch]: fix #25575 YandexGPTs for _grpc_metadata (#25617)

it fixes two issues:

### YGPTs are broken #25575

```
File ....conda/lib/python3.11/site-packages/langchain_community/embeddings/yandex.py:211, in _make_request(self, texts, **kwargs)
..
--> 211 res = stub.TextEmbedding(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]

AttributeError: 'YandexGPTEmbeddings' object has no attribute '_grpc_metadata'
```
My gut feeling is that #23841 is the cause.

I had to drop the leading underscore from `_grpc_metadata` as a quick fix,
but I just don't know how to do it in a sufficiently _pydantic_ way.

### minor issue:

If we use `api_key` (which is not the best practice), the code fails with:

```
File ~/git/...../python3.11/site-packages/langchain_community/embeddings/yandex.py:119, in YandexGPTEmbeddings.validate_environment(cls, values)
...

AttributeError: 'tuple' object has no attribute 'append'
```

- Added a new integration test, but it requires the YGPT environment to be
available and an active account. I don't know how integration tests are
enabled or disabled in CI.
- Added small unit tests with mocks. These should be fine.

---------

Co-authored-by: mikhail-khludnev <mikhail_khludnev@rntgroup.com>
This commit is contained in:
Mikhail Khludnev 2024-08-29 04:48:10 +03:00 committed by GitHub
parent 850bf89e48
commit a017f49fd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 297 additions and 20 deletions

View File

@ -170,7 +170,7 @@ def _make_request(
messages=[Message(**message) for message in message_history], messages=[Message(**message) for message in message_history],
) )
stub = TextGenerationServiceStub(channel) stub = TextGenerationServiceStub(channel)
res = stub.Completion(request, metadata=self._grpc_metadata) res = stub.Completion(request, metadata=self.grpc_metadata)
return list(res)[0].alternatives[0].message.text return list(res)[0].alternatives[0].message.text
@ -229,7 +229,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
messages=[Message(**message) for message in message_history], messages=[Message(**message) for message in message_history],
) )
stub = TextGenerationAsyncServiceStub(channel) stub = TextGenerationAsyncServiceStub(channel)
operation = await stub.Completion(request, metadata=self._grpc_metadata) operation = await stub.Completion(request, metadata=self.grpc_metadata)
async with grpc.aio.secure_channel( async with grpc.aio.secure_channel(
operation_api_url, channel_credentials operation_api_url, channel_credentials
) as operation_channel: ) as operation_channel:
@ -239,7 +239,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
operation_request = GetOperationRequest(operation_id=operation.id) operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get( operation = await operation_stub.Get(
operation_request, operation_request,
metadata=self._grpc_metadata, metadata=self.grpc_metadata,
) )
completion_response = CompletionResponse() completion_response = CompletionResponse()

View File

@ -69,7 +69,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
disable_request_logging: bool = False disable_request_logging: bool = False
"""YandexGPT API logs all request data by default. """YandexGPT API logs all request data by default.
If you provide personal data, confidential information, disable logging.""" If you provide personal data, confidential information, disable logging."""
_grpc_metadata: Sequence grpc_metadata: Sequence
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True
@ -91,15 +91,15 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "": if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "":
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.") raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
if values["iam_token"]: if values["iam_token"]:
values["_grpc_metadata"] = [ values["grpc_metadata"] = [
("authorization", f"Bearer {values['iam_token'].get_secret_value()}") ("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
] ]
if values["folder_id"]: if values["folder_id"]:
values["_grpc_metadata"].append(("x-folder-id", values["folder_id"])) values["grpc_metadata"].append(("x-folder-id", values["folder_id"]))
else: else:
values["_grpc_metadata"] = ( values["grpc_metadata"] = [
("authorization", f"Api-Key {values['api_key'].get_secret_value()}"), ("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
) ]
if not values.get("doc_model_uri"): if not values.get("doc_model_uri"):
if values["folder_id"] == "": if values["folder_id"] == "":
@ -114,7 +114,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}" f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
) )
if values["disable_request_logging"]: if values["disable_request_logging"]:
values["_grpc_metadata"].append( values["grpc_metadata"].append(
( (
"x-data-logging-enabled", "x-data-logging-enabled",
"false", "false",
@ -206,7 +206,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs): # typ
for text in texts: for text in texts:
request = TextEmbeddingRequest(model_uri=model_uri, text=text) request = TextEmbeddingRequest(model_uri=model_uri, text=text)
stub = EmbeddingsServiceStub(channel) stub = EmbeddingsServiceStub(channel)
res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] res = stub.TextEmbedding(request, metadata=self.grpc_metadata) # type: ignore[attr-defined]
result.append(list(res.embedding)) result.append(list(res.embedding))
time.sleep(self.sleep_interval) time.sleep(self.sleep_interval)

View File

@ -57,7 +57,7 @@ class _BaseYandexGPT(Serializable):
disable_request_logging: bool = False disable_request_logging: bool = False
"""YandexGPT API logs all request data by default. """YandexGPT API logs all request data by default.
If you provide personal data, confidential information, disable logging.""" If you provide personal data, confidential information, disable logging."""
_grpc_metadata: Optional[Sequence] = None grpc_metadata: Optional[Sequence] = None
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
@ -92,15 +92,15 @@ class _BaseYandexGPT(Serializable):
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.") raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
if values["iam_token"]: if values["iam_token"]:
values["_grpc_metadata"] = [ values["grpc_metadata"] = [
("authorization", f"Bearer {values['iam_token'].get_secret_value()}") ("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
] ]
if values["folder_id"]: if values["folder_id"]:
values["_grpc_metadata"].append(("x-folder-id", values["folder_id"])) values["grpc_metadata"].append(("x-folder-id", values["folder_id"]))
else: else:
values["_grpc_metadata"] = ( values["grpc_metadata"] = [
("authorization", f"Api-Key {values['api_key'].get_secret_value()}"), ("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
) ]
if values["model_uri"] == "" and values["folder_id"] == "": if values["model_uri"] == "" and values["folder_id"] == "":
raise ValueError("Either 'model_uri' or 'folder_id' must be provided.") raise ValueError("Either 'model_uri' or 'folder_id' must be provided.")
if not values["model_uri"]: if not values["model_uri"]:
@ -108,7 +108,7 @@ class _BaseYandexGPT(Serializable):
f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}" f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
) )
if values["disable_request_logging"]: if values["disable_request_logging"]:
values["_grpc_metadata"].append( values["grpc_metadata"].append(
( (
"x-data-logging-enabled", "x-data-logging-enabled",
"false", "false",
@ -235,7 +235,7 @@ def _make_request(
messages=[Message(role="user", text=prompt)], messages=[Message(role="user", text=prompt)],
) )
stub = TextGenerationServiceStub(channel) stub = TextGenerationServiceStub(channel)
res = stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] res = stub.Completion(request, metadata=self.grpc_metadata) # type: ignore[attr-defined]
return list(res)[0].alternatives[0].message.text return list(res)[0].alternatives[0].message.text
@ -291,7 +291,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
messages=[Message(role="user", text=prompt)], messages=[Message(role="user", text=prompt)],
) )
stub = TextGenerationAsyncServiceStub(channel) stub = TextGenerationAsyncServiceStub(channel)
operation = await stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] operation = await stub.Completion(request, metadata=self.grpc_metadata) # type: ignore[attr-defined]
async with grpc.aio.secure_channel( async with grpc.aio.secure_channel(
operation_api_url, channel_credentials operation_api_url, channel_credentials
) as operation_channel: ) as operation_channel:
@ -301,7 +301,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
operation_request = GetOperationRequest(operation_id=operation.id) operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get( operation = await operation_stub.Get(
operation_request, operation_request,
metadata=self._grpc_metadata, # type: ignore[attr-defined] metadata=self.grpc_metadata, # type: ignore[attr-defined]
) )
completion_response = CompletionResponse() completion_response = CompletionResponse()

View File

@ -0,0 +1,31 @@
import pytest
from langchain_community.embeddings.yandex import YandexGPTEmbeddings
@pytest.mark.parametrize(
    "constructor_args",
    [
        {},
        {"disable_request_logging": True},
    ],
)
# NOTE: the `@pytest.mark.scheduled` marker is deliberately not applied here.
# This test requires the YC_* environment variables and an active service.
def test_yandex_embedding(constructor_args: dict) -> None:
    """Embed documents and a query, then sanity-check the returned vectors.

    Args:
        constructor_args: Keyword arguments forwarded to
            ``YandexGPTEmbeddings``.
    """
    texts = ["exactly same", "exactly same", "different"]
    embedder = YandexGPTEmbeddings(**constructor_args)

    doc_vectors = embedder.embed_documents(texts)
    assert len(doc_vectors) == 3
    first_doc, second_doc, third_doc = doc_vectors

    for vector in doc_vectors:
        assert len(vector) >= 256  # there are many dims
        assert len(first_doc) == len(vector)  # dims are the same

    assert first_doc == second_doc  # same input, same embeddings
    assert third_doc != second_doc  # different input, different embeddings

    query_vector = embedder.embed_query(texts[0])
    assert len(query_vector) >= 256
    # query and doc models have same dimensions
    assert len(first_doc) == len(query_vector)
    assert first_doc != query_vector  # query and doc models are different

View File

@ -0,0 +1,92 @@
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain_community.chat_models.yandex import ChatYandexGPT
def test_yandexgpt_initialization() -> None:
    """Check the default model name and the model URI built from the folder id."""
    init_kwargs = {
        "iam_token": "your_iam_token",
        "api_key": "your_api_key",
        "folder_id": "your_folder_id",
    }
    chat_model = ChatYandexGPT(**init_kwargs)  # type: ignore[arg-type]

    assert chat_model.model_name == "yandexgpt-lite"
    expected_uri_prefix = "gpt://your_folder_id/yandexgpt-lite/"
    assert chat_model.model_uri.startswith(expected_uri_prefix)
def test_yandexgpt_model_params() -> None:
    """Check that explicit model name/version are kept and used in the URI."""
    init_kwargs = {
        "model_name": "custom-model",
        "model_version": "v1",
        "iam_token": "your_iam_token",
        "api_key": "your_api_key",
        "folder_id": "your_folder_id",
    }
    chat_model = ChatYandexGPT(**init_kwargs)  # type: ignore[arg-type]

    assert chat_model.model_name == "custom-model"
    assert chat_model.model_version == "v1"
    token_value = chat_model.iam_token.get_secret_value()
    assert token_value == "your_iam_token"
    assert chat_model.model_uri == "gpt://your_folder_id/custom-model/v1"
def test_yandexgpt_invalid_model_params() -> None:
    """Expect ``ValueError`` for an empty ``model_uri`` with no ``folder_id`` argument."""
    invalid_kwargs_sets = (
        {"model_uri": "", "iam_token": "your_iam_token"},
        {"iam_token": "", "api_key": "your_api_key", "model_uri": ""},
    )
    # Checked in order; the first unexpected success aborts the test.
    for invalid_kwargs in invalid_kwargs_sets:
        with pytest.raises(ValueError):
            ChatYandexGPT(**invalid_kwargs)  # type: ignore[arg-type]
@pytest.mark.parametrize(
    "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
)
@pytest.mark.parametrize(
    "disable_logging",
    [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Invoke ``ChatYandexGPT`` against mocked Yandex/gRPC modules.

    The Yandex protobuf modules and ``grpc`` are replaced in ``sys.modules``
    with mocks, so the test can inspect the request, the message and the
    gRPC metadata passed to the ``Completion`` call.

    Args:
        api_key_or_token: Credential keyword arguments (API key or IAM token).
        disable_logging: Optional ``disable_request_logging`` keyword argument.
    """
    yandex_modules = MagicMock()
    grpc_mock = MagicMock()
    grpc_mock.RpcError = Exception

    v1_package = "yandex.cloud.ai.foundation_models.v1"
    generation_package = f"{v1_package}.text_generation"
    fake_modules = {
        f"{v1_package}.text_common_pb2": yandex_modules,
        f"{generation_package}.text_generation_service_pb2": yandex_modules,
        f"{generation_package}.text_generation_service_pb2_grpc": yandex_modules,
        "grpc": grpc_mock,
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_modules.TextGenerationServiceStub
        request_cls = yandex_modules.CompletionRequest
        message_cls = yandex_modules.Message

        model = ChatYandexGPT(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        # Shape the mocked gRPC response as [result].alternatives[0].message.text
        alternative = mock.Mock()
        alternative.message.text = "cmpltn"
        result = mock.Mock()
        result.alternatives = [alternative]
        completion_mock = service_stub_cls.return_value.Completion
        completion_mock.return_value = [result]

        reply = model.invoke("nomatter")
        assert reply.content == "cmpltn"

        assert len(completion_mock.call_args_list) == 1
        completion_call = completion_mock.call_args_list[0]
        sent_model_uri = request_cls.call_args_list[0].kwargs["model_uri"]
        sent_text = message_cls.call_args_list[0].kwargs["text"]
        sent_metadata = completion_call.kwargs["metadata"]

        assert "fldr" in sent_model_uri
        assert sent_text == "nomatter"
        assert sent_metadata
        assert len(sent_metadata) > 0
        if disable_logging.get("disable_request_logging"):
            assert ("x-data-logging-enabled", "false") in sent_metadata

View File

@ -1,10 +1,21 @@
import os import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain_community.embeddings import YandexGPTEmbeddings from langchain_community.embeddings import YandexGPTEmbeddings
# Dotted module paths used as `sys.modules` keys when the tests below patch
# in mocks for the embedding service modules.
YANDEX_MODULE_NAME = (
    "yandex.cloud.ai.foundation_models.v1.embedding"
    ".embedding_service_pb2"
)
YANDEX_MODULE_NAME2 = YANDEX_MODULE_NAME + "_grpc"
@mock.patch.dict(os.environ, {"YC_API_KEY": "foo"}, clear=True)
def test_init() -> None: def test_init() -> None:
os.environ["YC_API_KEY"] = "foo"
models = [ models = [
YandexGPTEmbeddings(folder_id="bar"), # type: ignore[call-arg] YandexGPTEmbeddings(folder_id="bar"), # type: ignore[call-arg]
YandexGPTEmbeddings( # type: ignore[call-arg] YandexGPTEmbeddings( # type: ignore[call-arg]
@ -22,3 +33,92 @@ def test_init() -> None:
assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest" assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest"
assert embeddings.model_name == "text-search-query" assert embeddings.model_name == "text-search-query"
assert embeddings.doc_model_name == "text-search-doc" assert embeddings.doc_model_name == "text-search-doc"
@pytest.mark.parametrize(
    "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
)
@pytest.mark.parametrize(
    "disable_logging",
    [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_query_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Call ``embed_query`` against mocked Yandex/gRPC modules.

    Verifies the returned embedding, that exactly one gRPC call was made,
    and the model URI, text and metadata that were sent.

    Args:
        api_key_or_token: Credential keyword arguments (API key or IAM token).
        disable_logging: Optional ``disable_request_logging`` keyword argument.
    """
    yandex_modules = MagicMock()
    fake_modules = {
        YANDEX_MODULE_NAME: yandex_modules,
        YANDEX_MODULE_NAME2: yandex_modules,
        "grpc": MagicMock(),
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_modules.EmbeddingsServiceStub
        request_cls = yandex_modules.TextEmbeddingRequest

        embedder = YandexGPTEmbeddings(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        embedding_call_mock = service_stub_cls.return_value.TextEmbedding
        embedding_call_mock.return_value.embedding = [1, 2, 3]

        actual_embedding = embedder.embed_query("nomatter")
        assert actual_embedding == [1, 2, 3]

        assert len(embedding_call_mock.call_args_list) == 1
        grpc_call = embedding_call_mock.call_args_list[0]
        request_kwargs = request_cls.call_args_list[0].kwargs

        # The query model, not the document model, must be addressed.
        sent_model_uri = request_kwargs["model_uri"]
        assert "fldr" in sent_model_uri
        assert "query" in sent_model_uri
        assert "doc" not in sent_model_uri

        sent_text = request_cls.call_args_list[0].kwargs["text"]
        assert sent_text == "nomatter"

        sent_metadata = grpc_call.kwargs["metadata"]
        assert sent_metadata
        assert len(sent_metadata) > 0
        if disable_logging.get("disable_request_logging"):
            assert ("x-data-logging-enabled", "false") in sent_metadata
@pytest.mark.parametrize(
    "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
)
@pytest.mark.parametrize(
    "disable_logging",
    [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_doc_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Call ``embed_documents`` against mocked Yandex/gRPC modules.

    Verifies the returned embeddings, that one gRPC call is made per text,
    and the model URI, text and metadata sent with each call.

    Args:
        api_key_or_token: Credential keyword arguments (API key or IAM token).
        disable_logging: Optional ``disable_request_logging`` keyword argument.
    """
    yandex_modules = MagicMock()
    fake_modules = {
        YANDEX_MODULE_NAME: yandex_modules,
        YANDEX_MODULE_NAME2: yandex_modules,
        "grpc": MagicMock(),
    }
    expected_texts = ("foo", "bar")

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_modules.EmbeddingsServiceStub
        request_cls = yandex_modules.TextEmbeddingRequest

        embedder = YandexGPTEmbeddings(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        # One mocked response per document, returned in call order.
        foo_response = mock.Mock()
        foo_response.embedding = [1, 2, 3]
        bar_response = mock.Mock()
        bar_response.embedding = [4, 5, 6]
        embedding_call_mock = service_stub_cls.return_value.TextEmbedding
        embedding_call_mock.side_effect = [foo_response, bar_response]

        actual_embeddings = embedder.embed_documents(list(expected_texts))
        assert actual_embeddings == [[1, 2, 3], [4, 5, 6]]
        assert len(embedding_call_mock.call_args_list) == 2

        for idx, expected_text in enumerate(expected_texts):
            request_kwargs = request_cls.call_args_list[idx].kwargs
            sent_model_uri = request_kwargs["model_uri"]
            sent_text = request_cls.call_args_list[idx].kwargs["text"]
            grpc_call = embedding_call_mock.call_args_list[idx]
            sent_metadata = grpc_call.kwargs["metadata"]

            # The document model, not the query model, must be addressed.
            assert "fldr" in sent_model_uri
            assert "query" not in sent_model_uri
            assert "doc" in sent_model_uri
            assert sent_text == expected_text
            assert sent_metadata
            assert len(sent_metadata) > 0
            if disable_logging.get("disable_request_logging"):
                logging_header = ("x-data-logging-enabled", "false")
                assert logging_header in grpc_call.kwargs["metadata"]

View File

@ -1,3 +1,7 @@
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest import pytest
from langchain_community.llms.yandex import YandexGPT from langchain_community.llms.yandex import YandexGPT
@ -36,3 +40,53 @@ def test_yandexgpt_invalid_model_params() -> None:
api_key="your_api_key", # type: ignore[arg-type] api_key="your_api_key", # type: ignore[arg-type]
model_uri="", model_uri="",
) )
@pytest.mark.parametrize(
    "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
)
@pytest.mark.parametrize(
    "disable_logging",
    [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Invoke ``YandexGPT`` against mocked Yandex/gRPC modules.

    The Yandex protobuf modules and ``grpc`` are replaced in ``sys.modules``
    with mocks, so the test can inspect the request, the message and the
    gRPC metadata passed to the ``Completion`` call.

    Args:
        api_key_or_token: Credential keyword arguments (API key or IAM token).
        disable_logging: Optional ``disable_request_logging`` keyword argument.
    """
    yandex_modules = MagicMock()
    grpc_mock = MagicMock()
    grpc_mock.RpcError = Exception

    v1_package = "yandex.cloud.ai.foundation_models.v1"
    generation_package = f"{v1_package}.text_generation"
    fake_modules = {
        f"{v1_package}.text_common_pb2": yandex_modules,
        f"{generation_package}.text_generation_service_pb2": yandex_modules,
        f"{generation_package}.text_generation_service_pb2_grpc": yandex_modules,
        "grpc": grpc_mock,
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_modules.TextGenerationServiceStub
        request_cls = yandex_modules.CompletionRequest
        message_cls = yandex_modules.Message

        model = YandexGPT(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        # Shape the mocked gRPC response as [result].alternatives[0].message.text
        alternative = mock.Mock()
        alternative.message.text = "cmpltn"
        result = mock.Mock()
        result.alternatives = [alternative]
        completion_mock = service_stub_cls.return_value.Completion
        completion_mock.return_value = [result]

        completion_text = model.invoke("nomatter")
        assert completion_text == "cmpltn"

        assert len(completion_mock.call_args_list) == 1
        completion_call = completion_mock.call_args_list[0]
        sent_model_uri = request_cls.call_args_list[0].kwargs["model_uri"]
        sent_text = message_cls.call_args_list[0].kwargs["text"]
        sent_metadata = completion_call.kwargs["metadata"]

        assert "fldr" in sent_model_uri
        assert sent_text == "nomatter"
        assert sent_metadata
        assert len(sent_metadata) > 0
        if disable_logging.get("disable_request_logging"):
            assert ("x-data-logging-enabled", "false") in sent_metadata