community[patch]: fix #25575 YandexGPTs for _grpc_metadata (#25617)

it fixes two issues:

### YGPTs are broken #25575

```
File ....conda/lib/python3.11/site-packages/langchain_community/embeddings/yandex.py:211, in _make_request(self, texts, **kwargs)
..
--> 211 res = stub.TextEmbedding(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]

AttributeError: 'YandexGPTEmbeddings' object has no attribute '_grpc_metadata'
```
My gut feeling is that #23841 is the cause.

As a quick fix I had to drop the leading underscore from `_grpc_metadata`,
but I don't know how to do it in a sufficiently _pydantic_ way.

### minor issue:

If we use `api_key` (which is not the best practice), the code fails with:

```
File ~/git/...../python3.11/site-packages/langchain_community/embeddings/yandex.py:119, in YandexGPTEmbeddings.validate_environment(cls, values)
...

AttributeError: 'tuple' object has no attribute 'append'
```

- Added a new integration test, but it requires a YGPT environment and an
active account. I don't know how integration tests are enabled or disabled in CI.
- Added small unit tests with mocks. These should be fine.

---------

Co-authored-by: mikhail-khludnev <mikhail_khludnev@rntgroup.com>
This commit is contained in:
Mikhail Khludnev
2024-08-29 04:48:10 +03:00
committed by GitHub
parent 850bf89e48
commit a017f49fd3
7 changed files with 297 additions and 20 deletions

View File

@@ -0,0 +1,31 @@
import pytest
from langchain_community.embeddings.yandex import YandexGPTEmbeddings
@pytest.mark.parametrize(
    "constructor_args",
    [
        {},
        {"disable_request_logging": True},
    ],
)
def test_yandex_embedding(constructor_args: dict) -> None:
    """Embed documents and a query with YandexGPTEmbeddings and compare them.

    Requires the YC_* environment variables and an active service.
    NOTE(review): not marked with ``@pytest.mark.scheduled``; confirm whether
    that marker is needed for this test.
    """
    texts = ["exactly same", "exactly same", "different"]
    embedder = YandexGPTEmbeddings(**constructor_args)

    doc_vectors = embedder.embed_documents(texts)
    assert len(doc_vectors) == 3
    first_vec, second_vec, third_vec = doc_vectors
    for vec in doc_vectors:
        # Every vector has many dimensions, and all share the same size.
        assert len(vec) >= 256
        assert len(first_vec) == len(vec)
    # Identical inputs give identical embeddings; different inputs do not.
    assert first_vec == second_vec
    assert third_vec != second_vec

    query_vec = embedder.embed_query(texts[0])
    assert len(query_vec) >= 256
    # Query and document models agree on dimensionality ...
    assert len(first_vec) == len(query_vec)
    # ... but embed the same text differently.
    assert first_vec != query_vec

View File

@@ -0,0 +1,92 @@
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain_community.chat_models.yandex import ChatYandexGPT
def test_yandexgpt_initialization() -> None:
    """Check the default model name and model URI of a ChatYandexGPT instance."""
    chat = ChatYandexGPT(
        iam_token="your_iam_token",  # type: ignore[arg-type]
        api_key="your_api_key",  # type: ignore[arg-type]
        folder_id="your_folder_id",
    )

    expected_uri_prefix = "gpt://your_folder_id/yandexgpt-lite/"
    assert chat.model_name == "yandexgpt-lite"
    assert chat.model_uri.startswith(expected_uri_prefix)
def test_yandexgpt_model_params() -> None:
    """Check that explicit model name and version are kept and form the URI."""
    chat = ChatYandexGPT(
        model_name="custom-model",
        model_version="v1",
        iam_token="your_iam_token",  # type: ignore[arg-type]
        api_key="your_api_key",  # type: ignore[arg-type]
        folder_id="your_folder_id",
    )

    assert chat.model_name == "custom-model"
    assert chat.model_version == "v1"
    # The token is read back through get_secret_value().
    token_value = chat.iam_token.get_secret_value()
    assert token_value == "your_iam_token"
    assert chat.model_uri == "gpt://your_folder_id/custom-model/v1"
def test_yandexgpt_invalid_model_params() -> None:
    """Check that ChatYandexGPT raises ValueError for an empty ``model_uri``."""
    invalid_kwargs = (
        {"model_uri": "", "iam_token": "your_iam_token"},
        {"iam_token": "", "api_key": "your_api_key", "model_uri": ""},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            ChatYandexGPT(**kwargs)  # type: ignore[arg-type]
@pytest.mark.parametrize(
    "api_key_or_token", [{"api_key": "bogus"}, {"iam_token": "bogus"}]
)
@pytest.mark.parametrize(
    "disable_logging",
    [{}, {"disable_request_logging": True}, {"disable_request_logging": False}],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Invoke ChatYandexGPT against mocked Yandex/grpc modules.

    The Yandex protobuf modules and ``grpc`` are replaced in ``sys.modules``
    by mocks and the environment is cleared, so credentials come only from
    the constructor arguments. The test checks the returned content and the
    model URI, message text and gRPC metadata passed to the mocks.
    """
    yandex_stub = MagicMock()
    grpc_stub = MagicMock()
    grpc_stub.RpcError = Exception
    fake_modules = {
        "yandex.cloud.ai.foundation_models.v1."
        "text_common_pb2": yandex_stub,
        "yandex.cloud.ai.foundation_models.v1.text_generation."
        "text_generation_service_pb2": yandex_stub,
        "yandex.cloud.ai.foundation_models.v1.text_generation."
        "text_generation_service_pb2_grpc": yandex_stub,
        "grpc": grpc_stub,
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_stub.TextGenerationServiceStub
        request_cls = yandex_stub.CompletionRequest
        message_cls = yandex_stub.Message

        chat = ChatYandexGPT(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        # The Completion call yields one response holding one alternative.
        completion_rpc = service_stub_cls.return_value.Completion
        alternative = mock.Mock()
        alternative.message.text = "cmpltn"
        completion_rpc.return_value = [mock.Mock(alternatives=[alternative])]

        reply = chat.invoke("nomatter")

        assert reply.content == "cmpltn"
        assert len(completion_rpc.call_args_list) == 1

        rpc_call = completion_rpc.call_args_list[0]
        sent_model_uri = request_cls.call_args_list[0].kwargs["model_uri"]
        sent_text = message_cls.call_args_list[0].kwargs["text"]
        sent_metadata = rpc_call.kwargs["metadata"]

        assert "fldr" in sent_model_uri
        assert sent_text == "nomatter"
        assert sent_metadata
        assert len(sent_metadata) > 0
        if disable_logging.get("disable_request_logging"):
            assert ("x-data-logging-enabled", "false") in sent_metadata

View File

@@ -1,10 +1,21 @@
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain_community.embeddings import YandexGPTEmbeddings
# Dotted module paths of the Yandex embedding service modules. The tests in
# this file use them as ``sys.modules`` keys to substitute mocks.
# Path of the ``embedding_service_pb2_grpc`` module.
YANDEX_MODULE_NAME2 = (
"yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2_grpc"
)
# Path of the ``embedding_service_pb2`` module.
YANDEX_MODULE_NAME = (
"yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2"
)
@mock.patch.dict(os.environ, {"YC_API_KEY": "foo"}, clear=True)
def test_init() -> None:
os.environ["YC_API_KEY"] = "foo"
models = [
YandexGPTEmbeddings(folder_id="bar"), # type: ignore[call-arg]
YandexGPTEmbeddings( # type: ignore[call-arg]
@@ -22,3 +33,92 @@ def test_init() -> None:
assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest"
assert embeddings.model_name == "text-search-query"
assert embeddings.doc_model_name == "text-search-doc"
@pytest.mark.parametrize(
    "api_key_or_token", [{"api_key": "bogus"}, {"iam_token": "bogus"}]
)
@pytest.mark.parametrize(
    "disable_logging",
    [{}, {"disable_request_logging": True}, {"disable_request_logging": False}],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_query_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Run ``embed_query`` against mocked Yandex/grpc modules.

    The Yandex embedding modules and ``grpc`` are replaced in ``sys.modules``
    by mocks and the environment is cleared. The test checks the returned
    embedding and the model URI, text and gRPC metadata passed to the mocks.
    """
    yandex_stub = MagicMock()
    fake_modules = {
        YANDEX_MODULE_NAME: yandex_stub,
        YANDEX_MODULE_NAME2: yandex_stub,
        "grpc": MagicMock(),
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_stub.EmbeddingsServiceStub
        request_cls = yandex_stub.TextEmbeddingRequest

        embedder = YandexGPTEmbeddings(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        embedding_rpc = service_stub_cls.return_value.TextEmbedding
        embedding_rpc.return_value.embedding = [1, 2, 3]

        vector = embedder.embed_query("nomatter")

        assert vector == [1, 2, 3]
        assert len(embedding_rpc.call_args_list) == 1
        rpc_call = embedding_rpc.call_args_list[0]

        # The request must target the query model, not the document model.
        request_kwargs = request_cls.call_args_list[0].kwargs
        sent_model_uri = request_kwargs["model_uri"]
        assert "fldr" in sent_model_uri
        assert "query" in sent_model_uri
        assert "doc" not in sent_model_uri

        sent_text = request_cls.call_args_list[0].kwargs["text"]
        assert sent_text == "nomatter"

        sent_metadata = rpc_call.kwargs["metadata"]
        assert sent_metadata
        assert len(sent_metadata) > 0
        if disable_logging.get("disable_request_logging"):
            assert ("x-data-logging-enabled", "false") in sent_metadata
@pytest.mark.parametrize(
    "api_key_or_token", [{"api_key": "bogus"}, {"iam_token": "bogus"}]
)
@pytest.mark.parametrize(
    "disable_logging",
    [{}, {"disable_request_logging": True}, {"disable_request_logging": False}],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_doc_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Run ``embed_documents`` against mocked Yandex/grpc modules.

    The Yandex embedding modules and ``grpc`` are replaced in ``sys.modules``
    by mocks and the environment is cleared. Two documents are embedded; for
    each call the test checks the model URI, text and gRPC metadata passed to
    the mocks.
    """
    yandex_stub = MagicMock()
    fake_modules = {
        YANDEX_MODULE_NAME: yandex_stub,
        YANDEX_MODULE_NAME2: yandex_stub,
        "grpc": MagicMock(),
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_stub.EmbeddingsServiceStub
        request_cls = yandex_stub.TextEmbeddingRequest

        embedder = YandexGPTEmbeddings(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        # Successive TextEmbedding calls return these two responses in order.
        embedding_rpc = service_stub_cls.return_value.TextEmbedding
        embedding_rpc.side_effect = [
            mock.Mock(embedding=[1, 2, 3]),
            mock.Mock(embedding=[4, 5, 6]),
        ]

        vectors = embedder.embed_documents(["foo", "bar"])

        assert vectors == [[1, 2, 3], [4, 5, 6]]
        assert len(embedding_rpc.call_args_list) == 2

        for idx, expected_text in enumerate(["foo", "bar"]):
            request_kwargs = request_cls.call_args_list[idx].kwargs
            sent_model_uri = request_kwargs["model_uri"]
            sent_text = request_cls.call_args_list[idx].kwargs["text"]
            rpc_call = embedding_rpc.call_args_list[idx]
            sent_metadata = rpc_call.kwargs["metadata"]

            # The request must target the document model, not the query model.
            assert "fldr" in sent_model_uri
            assert "query" not in sent_model_uri
            assert "doc" in sent_model_uri
            assert sent_text == expected_text
            assert sent_metadata
            assert len(sent_metadata) > 0
            if disable_logging.get("disable_request_logging"):
                logging_header = ("x-data-logging-enabled", "false")
                assert logging_header in rpc_call.kwargs["metadata"]

View File

@@ -1,3 +1,7 @@
import os
from unittest import mock
from unittest.mock import MagicMock
import pytest
from langchain_community.llms.yandex import YandexGPT
@@ -36,3 +40,53 @@ def test_yandexgpt_invalid_model_params() -> None:
api_key="your_api_key", # type: ignore[arg-type]
model_uri="",
)
@pytest.mark.parametrize(
    "api_key_or_token", [{"api_key": "bogus"}, {"iam_token": "bogus"}]
)
@pytest.mark.parametrize(
    "disable_logging",
    [{}, {"disable_request_logging": True}, {"disable_request_logging": False}],
)
@mock.patch.dict(os.environ, {}, clear=True)
def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None:
    """Invoke the YandexGPT LLM against mocked Yandex/grpc modules.

    The Yandex protobuf modules and ``grpc`` are replaced in ``sys.modules``
    by mocks and the environment is cleared, so credentials come only from
    the constructor arguments. The test checks the returned text and the
    model URI, message text and gRPC metadata passed to the mocks.
    """
    yandex_stub = MagicMock()
    grpc_stub = MagicMock()
    grpc_stub.RpcError = Exception
    fake_modules = {
        "yandex.cloud.ai.foundation_models.v1."
        "text_common_pb2": yandex_stub,
        "yandex.cloud.ai.foundation_models.v1.text_generation."
        "text_generation_service_pb2": yandex_stub,
        "yandex.cloud.ai.foundation_models.v1.text_generation."
        "text_generation_service_pb2_grpc": yandex_stub,
        "grpc": grpc_stub,
    }

    with mock.patch.dict("sys.modules", fake_modules):
        service_stub_cls = yandex_stub.TextGenerationServiceStub
        request_cls = yandex_stub.CompletionRequest
        message_cls = yandex_stub.Message

        llm = YandexGPT(
            **{"folder_id": "fldr", **api_key_or_token, **disable_logging}
        )

        # The Completion call yields one response holding one alternative.
        completion_rpc = service_stub_cls.return_value.Completion
        alternative = mock.Mock()
        alternative.message.text = "cmpltn"
        completion_rpc.return_value = [mock.Mock(alternatives=[alternative])]

        completion = llm.invoke("nomatter")

        assert completion == "cmpltn"
        assert len(completion_rpc.call_args_list) == 1

        rpc_call = completion_rpc.call_args_list[0]
        sent_model_uri = request_cls.call_args_list[0].kwargs["model_uri"]
        sent_text = message_cls.call_args_list[0].kwargs["text"]
        sent_metadata = rpc_call.kwargs["metadata"]

        assert "fldr" in sent_model_uri
        assert sent_text == "nomatter"
        assert sent_metadata
        assert len(sent_metadata) > 0
        if disable_logging.get("disable_request_logging"):
            assert ("x-data-logging-enabled", "false") in sent_metadata