mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
it fixes two issues: ### YGPTs are broken #25575 ``` File ....conda/lib/python3.11/site-packages/langchain_community/embeddings/yandex.py:211, in _make_request(self, texts, **kwargs) .. --> 211 res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] AttributeError: 'YandexGPTEmbeddings' object has no attribute '_grpc_metadata' ``` My gut feeling that #23841 is the cause. I have to drop leading underscore from `_grpc_metadata` for quickfix, but I just don't know how to do it _pydantic_ enough. ### minor issue: if we use `api_key`, which is not the best practice the code fails with ``` File ~/git/...../python3.11/site-packages/langchain_community/embeddings/yandex.py:119, in YandexGPTEmbeddings.validate_environment(cls, values) ... AttributeError: 'tuple' object has no attribute 'append' ``` - Added new integration test. But it requires YGPT env available and active account. I don't know how int tests dis\enabled in CI. - added small unit tests with mocks. Should be fine. --------- Co-authored-by: mikhail-khludnev <mikhail_khludnev@rntgroup.com>
This commit is contained in:
parent
850bf89e48
commit
a017f49fd3
@ -170,7 +170,7 @@ def _make_request(
|
|||||||
messages=[Message(**message) for message in message_history],
|
messages=[Message(**message) for message in message_history],
|
||||||
)
|
)
|
||||||
stub = TextGenerationServiceStub(channel)
|
stub = TextGenerationServiceStub(channel)
|
||||||
res = stub.Completion(request, metadata=self._grpc_metadata)
|
res = stub.Completion(request, metadata=self.grpc_metadata)
|
||||||
return list(res)[0].alternatives[0].message.text
|
return list(res)[0].alternatives[0].message.text
|
||||||
|
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
|
|||||||
messages=[Message(**message) for message in message_history],
|
messages=[Message(**message) for message in message_history],
|
||||||
)
|
)
|
||||||
stub = TextGenerationAsyncServiceStub(channel)
|
stub = TextGenerationAsyncServiceStub(channel)
|
||||||
operation = await stub.Completion(request, metadata=self._grpc_metadata)
|
operation = await stub.Completion(request, metadata=self.grpc_metadata)
|
||||||
async with grpc.aio.secure_channel(
|
async with grpc.aio.secure_channel(
|
||||||
operation_api_url, channel_credentials
|
operation_api_url, channel_credentials
|
||||||
) as operation_channel:
|
) as operation_channel:
|
||||||
@ -239,7 +239,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
|
|||||||
operation_request = GetOperationRequest(operation_id=operation.id)
|
operation_request = GetOperationRequest(operation_id=operation.id)
|
||||||
operation = await operation_stub.Get(
|
operation = await operation_stub.Get(
|
||||||
operation_request,
|
operation_request,
|
||||||
metadata=self._grpc_metadata,
|
metadata=self.grpc_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
completion_response = CompletionResponse()
|
completion_response = CompletionResponse()
|
||||||
|
@ -69,7 +69,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
|
|||||||
disable_request_logging: bool = False
|
disable_request_logging: bool = False
|
||||||
"""YandexGPT API logs all request data by default.
|
"""YandexGPT API logs all request data by default.
|
||||||
If you provide personal data, confidential information, disable logging."""
|
If you provide personal data, confidential information, disable logging."""
|
||||||
_grpc_metadata: Sequence
|
grpc_metadata: Sequence
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
allow_population_by_field_name = True
|
allow_population_by_field_name = True
|
||||||
@ -91,15 +91,15 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
|
|||||||
if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "":
|
if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "":
|
||||||
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
|
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
|
||||||
if values["iam_token"]:
|
if values["iam_token"]:
|
||||||
values["_grpc_metadata"] = [
|
values["grpc_metadata"] = [
|
||||||
("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
|
("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
|
||||||
]
|
]
|
||||||
if values["folder_id"]:
|
if values["folder_id"]:
|
||||||
values["_grpc_metadata"].append(("x-folder-id", values["folder_id"]))
|
values["grpc_metadata"].append(("x-folder-id", values["folder_id"]))
|
||||||
else:
|
else:
|
||||||
values["_grpc_metadata"] = (
|
values["grpc_metadata"] = [
|
||||||
("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
|
("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
|
||||||
)
|
]
|
||||||
|
|
||||||
if not values.get("doc_model_uri"):
|
if not values.get("doc_model_uri"):
|
||||||
if values["folder_id"] == "":
|
if values["folder_id"] == "":
|
||||||
@ -114,7 +114,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
|
|||||||
f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
|
f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
|
||||||
)
|
)
|
||||||
if values["disable_request_logging"]:
|
if values["disable_request_logging"]:
|
||||||
values["_grpc_metadata"].append(
|
values["grpc_metadata"].append(
|
||||||
(
|
(
|
||||||
"x-data-logging-enabled",
|
"x-data-logging-enabled",
|
||||||
"false",
|
"false",
|
||||||
@ -206,7 +206,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs): # typ
|
|||||||
for text in texts:
|
for text in texts:
|
||||||
request = TextEmbeddingRequest(model_uri=model_uri, text=text)
|
request = TextEmbeddingRequest(model_uri=model_uri, text=text)
|
||||||
stub = EmbeddingsServiceStub(channel)
|
stub = EmbeddingsServiceStub(channel)
|
||||||
res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
|
res = stub.TextEmbedding(request, metadata=self.grpc_metadata) # type: ignore[attr-defined]
|
||||||
result.append(list(res.embedding))
|
result.append(list(res.embedding))
|
||||||
time.sleep(self.sleep_interval)
|
time.sleep(self.sleep_interval)
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ class _BaseYandexGPT(Serializable):
|
|||||||
disable_request_logging: bool = False
|
disable_request_logging: bool = False
|
||||||
"""YandexGPT API logs all request data by default.
|
"""YandexGPT API logs all request data by default.
|
||||||
If you provide personal data, confidential information, disable logging."""
|
If you provide personal data, confidential information, disable logging."""
|
||||||
_grpc_metadata: Optional[Sequence] = None
|
grpc_metadata: Optional[Sequence] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
@ -92,15 +92,15 @@ class _BaseYandexGPT(Serializable):
|
|||||||
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
|
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
|
||||||
|
|
||||||
if values["iam_token"]:
|
if values["iam_token"]:
|
||||||
values["_grpc_metadata"] = [
|
values["grpc_metadata"] = [
|
||||||
("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
|
("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
|
||||||
]
|
]
|
||||||
if values["folder_id"]:
|
if values["folder_id"]:
|
||||||
values["_grpc_metadata"].append(("x-folder-id", values["folder_id"]))
|
values["grpc_metadata"].append(("x-folder-id", values["folder_id"]))
|
||||||
else:
|
else:
|
||||||
values["_grpc_metadata"] = (
|
values["grpc_metadata"] = [
|
||||||
("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
|
("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
|
||||||
)
|
]
|
||||||
if values["model_uri"] == "" and values["folder_id"] == "":
|
if values["model_uri"] == "" and values["folder_id"] == "":
|
||||||
raise ValueError("Either 'model_uri' or 'folder_id' must be provided.")
|
raise ValueError("Either 'model_uri' or 'folder_id' must be provided.")
|
||||||
if not values["model_uri"]:
|
if not values["model_uri"]:
|
||||||
@ -108,7 +108,7 @@ class _BaseYandexGPT(Serializable):
|
|||||||
f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
|
f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
|
||||||
)
|
)
|
||||||
if values["disable_request_logging"]:
|
if values["disable_request_logging"]:
|
||||||
values["_grpc_metadata"].append(
|
values["grpc_metadata"].append(
|
||||||
(
|
(
|
||||||
"x-data-logging-enabled",
|
"x-data-logging-enabled",
|
||||||
"false",
|
"false",
|
||||||
@ -235,7 +235,7 @@ def _make_request(
|
|||||||
messages=[Message(role="user", text=prompt)],
|
messages=[Message(role="user", text=prompt)],
|
||||||
)
|
)
|
||||||
stub = TextGenerationServiceStub(channel)
|
stub = TextGenerationServiceStub(channel)
|
||||||
res = stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
|
res = stub.Completion(request, metadata=self.grpc_metadata) # type: ignore[attr-defined]
|
||||||
return list(res)[0].alternatives[0].message.text
|
return list(res)[0].alternatives[0].message.text
|
||||||
|
|
||||||
|
|
||||||
@ -291,7 +291,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
|
|||||||
messages=[Message(role="user", text=prompt)],
|
messages=[Message(role="user", text=prompt)],
|
||||||
)
|
)
|
||||||
stub = TextGenerationAsyncServiceStub(channel)
|
stub = TextGenerationAsyncServiceStub(channel)
|
||||||
operation = await stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
|
operation = await stub.Completion(request, metadata=self.grpc_metadata) # type: ignore[attr-defined]
|
||||||
async with grpc.aio.secure_channel(
|
async with grpc.aio.secure_channel(
|
||||||
operation_api_url, channel_credentials
|
operation_api_url, channel_credentials
|
||||||
) as operation_channel:
|
) as operation_channel:
|
||||||
@ -301,7 +301,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
|
|||||||
operation_request = GetOperationRequest(operation_id=operation.id)
|
operation_request = GetOperationRequest(operation_id=operation.id)
|
||||||
operation = await operation_stub.Get(
|
operation = await operation_stub.Get(
|
||||||
operation_request,
|
operation_request,
|
||||||
metadata=self._grpc_metadata, # type: ignore[attr-defined]
|
metadata=self.grpc_metadata, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
completion_response = CompletionResponse()
|
completion_response = CompletionResponse()
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_community.embeddings.yandex import YandexGPTEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"constructor_args",
|
||||||
|
[
|
||||||
|
dict(),
|
||||||
|
dict(disable_request_logging=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# @pytest.mark.scheduled - idk what it means
|
||||||
|
# requires YC_* env and active service
|
||||||
|
def test_yandex_embedding(constructor_args: dict) -> None:
|
||||||
|
documents = ["exactly same", "exactly same", "different"]
|
||||||
|
embedding = YandexGPTEmbeddings(**constructor_args)
|
||||||
|
doc_outputs = embedding.embed_documents(documents)
|
||||||
|
assert len(doc_outputs) == 3
|
||||||
|
for i in range(3):
|
||||||
|
assert len(doc_outputs[i]) >= 256 # there are many dims
|
||||||
|
assert len(doc_outputs[0]) == len(doc_outputs[i]) # dims are te same
|
||||||
|
assert doc_outputs[0] == doc_outputs[1] # same input, same embeddings
|
||||||
|
assert doc_outputs[2] != doc_outputs[1] # different input, different embeddings
|
||||||
|
|
||||||
|
qry_output = embedding.embed_query(documents[0])
|
||||||
|
assert len(qry_output) >= 256
|
||||||
|
assert len(doc_outputs[0]) == len(
|
||||||
|
qry_output
|
||||||
|
) # query and doc models have same dimensions
|
||||||
|
assert doc_outputs[0] != qry_output # query and doc models are different
|
92
libs/community/tests/unit_tests/chat_models/test_yandex.py
Normal file
92
libs/community/tests/unit_tests/chat_models/test_yandex.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_community.chat_models.yandex import ChatYandexGPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_yandexgpt_initialization() -> None:
|
||||||
|
llm = ChatYandexGPT(
|
||||||
|
iam_token="your_iam_token", # type: ignore[arg-type]
|
||||||
|
api_key="your_api_key", # type: ignore[arg-type]
|
||||||
|
folder_id="your_folder_id",
|
||||||
|
)
|
||||||
|
assert llm.model_name == "yandexgpt-lite"
|
||||||
|
assert llm.model_uri.startswith("gpt://your_folder_id/yandexgpt-lite/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_yandexgpt_model_params() -> None:
|
||||||
|
llm = ChatYandexGPT(
|
||||||
|
model_name="custom-model",
|
||||||
|
model_version="v1",
|
||||||
|
iam_token="your_iam_token", # type: ignore[arg-type]
|
||||||
|
api_key="your_api_key", # type: ignore[arg-type]
|
||||||
|
folder_id="your_folder_id",
|
||||||
|
)
|
||||||
|
assert llm.model_name == "custom-model"
|
||||||
|
assert llm.model_version == "v1"
|
||||||
|
assert llm.iam_token.get_secret_value() == "your_iam_token"
|
||||||
|
assert llm.model_uri == "gpt://your_folder_id/custom-model/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_yandexgpt_invalid_model_params() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ChatYandexGPT(model_uri="", iam_token="your_iam_token") # type: ignore[arg-type]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ChatYandexGPT(
|
||||||
|
iam_token="", # type: ignore[arg-type]
|
||||||
|
api_key="your_api_key", # type: ignore[arg-type]
|
||||||
|
model_uri="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"disable_logging",
|
||||||
|
[dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
|
||||||
|
)
|
||||||
|
@mock.patch.dict(os.environ, {}, clear=True)
|
||||||
|
def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None:
|
||||||
|
absent_yandex_module_stub = MagicMock()
|
||||||
|
grpc_mock = MagicMock()
|
||||||
|
with mock.patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
"yandex.cloud.ai.foundation_models.v1."
|
||||||
|
"text_common_pb2": absent_yandex_module_stub,
|
||||||
|
"yandex.cloud.ai.foundation_models.v1.text_generation."
|
||||||
|
"text_generation_service_pb2": absent_yandex_module_stub,
|
||||||
|
"yandex.cloud.ai.foundation_models.v1.text_generation."
|
||||||
|
"text_generation_service_pb2_grpc": absent_yandex_module_stub,
|
||||||
|
"grpc": grpc_mock,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
grpc_mock.RpcError = Exception
|
||||||
|
stub = absent_yandex_module_stub.TextGenerationServiceStub
|
||||||
|
request_stub = absent_yandex_module_stub.CompletionRequest
|
||||||
|
msg_constructor_stub = absent_yandex_module_stub.Message
|
||||||
|
args = {"folder_id": "fldr", **api_key_or_token, **disable_logging}
|
||||||
|
ygpt = ChatYandexGPT(**args)
|
||||||
|
grpc_call_mock = stub.return_value.Completion
|
||||||
|
msg_mock = mock.Mock()
|
||||||
|
msg_mock.message.text = "cmpltn"
|
||||||
|
res_mock = mock.Mock()
|
||||||
|
res_mock.alternatives = [msg_mock]
|
||||||
|
grpc_call_mock.return_value = [res_mock]
|
||||||
|
act_emb = ygpt.invoke("nomatter")
|
||||||
|
assert act_emb.content == "cmpltn"
|
||||||
|
assert len(grpc_call_mock.call_args_list) == 1
|
||||||
|
once_called_args = grpc_call_mock.call_args_list[0]
|
||||||
|
act_model_uri = request_stub.call_args_list[0].kwargs["model_uri"]
|
||||||
|
act_text = msg_constructor_stub.call_args_list[0].kwargs["text"]
|
||||||
|
act_metadata = once_called_args.kwargs["metadata"]
|
||||||
|
assert "fldr" in act_model_uri
|
||||||
|
assert act_text == "nomatter"
|
||||||
|
assert act_metadata
|
||||||
|
assert len(act_metadata) > 0
|
||||||
|
if disable_logging.get("disable_request_logging"):
|
||||||
|
assert ("x-data-logging-enabled", "false") in act_metadata
|
@ -1,10 +1,21 @@
|
|||||||
import os
|
import os
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_community.embeddings import YandexGPTEmbeddings
|
from langchain_community.embeddings import YandexGPTEmbeddings
|
||||||
|
|
||||||
|
YANDEX_MODULE_NAME2 = (
|
||||||
|
"yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2_grpc"
|
||||||
|
)
|
||||||
|
YANDEX_MODULE_NAME = (
|
||||||
|
"yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch.dict(os.environ, {"YC_API_KEY": "foo"}, clear=True)
|
||||||
def test_init() -> None:
|
def test_init() -> None:
|
||||||
os.environ["YC_API_KEY"] = "foo"
|
|
||||||
models = [
|
models = [
|
||||||
YandexGPTEmbeddings(folder_id="bar"), # type: ignore[call-arg]
|
YandexGPTEmbeddings(folder_id="bar"), # type: ignore[call-arg]
|
||||||
YandexGPTEmbeddings( # type: ignore[call-arg]
|
YandexGPTEmbeddings( # type: ignore[call-arg]
|
||||||
@ -22,3 +33,92 @@ def test_init() -> None:
|
|||||||
assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest"
|
assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest"
|
||||||
assert embeddings.model_name == "text-search-query"
|
assert embeddings.model_name == "text-search-query"
|
||||||
assert embeddings.doc_model_name == "text-search-doc"
|
assert embeddings.doc_model_name == "text-search-doc"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"disable_logging",
|
||||||
|
[dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
|
||||||
|
)
|
||||||
|
@mock.patch.dict(os.environ, {}, clear=True)
|
||||||
|
def test_query_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None:
|
||||||
|
absent_yandex_module_stub = MagicMock()
|
||||||
|
with mock.patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
YANDEX_MODULE_NAME: absent_yandex_module_stub,
|
||||||
|
YANDEX_MODULE_NAME2: absent_yandex_module_stub,
|
||||||
|
"grpc": MagicMock(),
|
||||||
|
},
|
||||||
|
):
|
||||||
|
stub = absent_yandex_module_stub.EmbeddingsServiceStub
|
||||||
|
request_stub = absent_yandex_module_stub.TextEmbeddingRequest
|
||||||
|
args = {"folder_id": "fldr", **api_key_or_token, **disable_logging}
|
||||||
|
ygpt = YandexGPTEmbeddings(**args)
|
||||||
|
grpc_call_mock = stub.return_value.TextEmbedding
|
||||||
|
grpc_call_mock.return_value.embedding = [1, 2, 3]
|
||||||
|
act_emb = ygpt.embed_query("nomatter")
|
||||||
|
assert act_emb == [1, 2, 3]
|
||||||
|
assert len(grpc_call_mock.call_args_list) == 1
|
||||||
|
once_called_args = grpc_call_mock.call_args_list[0]
|
||||||
|
act_model_uri = request_stub.call_args_list[0].kwargs["model_uri"]
|
||||||
|
assert "fldr" in act_model_uri
|
||||||
|
assert "query" in act_model_uri
|
||||||
|
assert "doc" not in act_model_uri
|
||||||
|
act_text = request_stub.call_args_list[0].kwargs["text"]
|
||||||
|
assert act_text == "nomatter"
|
||||||
|
act_metadata = once_called_args.kwargs["metadata"]
|
||||||
|
assert act_metadata
|
||||||
|
assert len(act_metadata) > 0
|
||||||
|
if disable_logging.get("disable_request_logging"):
|
||||||
|
assert ("x-data-logging-enabled", "false") in act_metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"disable_logging",
|
||||||
|
[dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
|
||||||
|
)
|
||||||
|
@mock.patch.dict(os.environ, {}, clear=True)
|
||||||
|
def test_doc_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None:
|
||||||
|
absent_yandex_module_stub = MagicMock()
|
||||||
|
with mock.patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
YANDEX_MODULE_NAME: absent_yandex_module_stub,
|
||||||
|
YANDEX_MODULE_NAME2: absent_yandex_module_stub,
|
||||||
|
"grpc": MagicMock(),
|
||||||
|
},
|
||||||
|
):
|
||||||
|
stub = absent_yandex_module_stub.EmbeddingsServiceStub
|
||||||
|
request_stub = absent_yandex_module_stub.TextEmbeddingRequest
|
||||||
|
args = {"folder_id": "fldr", **api_key_or_token, **disable_logging}
|
||||||
|
ygpt = YandexGPTEmbeddings(**args)
|
||||||
|
grpc_call_mock = stub.return_value.TextEmbedding
|
||||||
|
foo_emb = mock.Mock()
|
||||||
|
foo_emb.embedding = [1, 2, 3]
|
||||||
|
bar_emb = mock.Mock()
|
||||||
|
bar_emb.embedding = [4, 5, 6]
|
||||||
|
grpc_call_mock.side_effect = [foo_emb, bar_emb]
|
||||||
|
act_emb = ygpt.embed_documents(["foo", "bar"])
|
||||||
|
assert act_emb == [[1, 2, 3], [4, 5, 6]]
|
||||||
|
assert len(grpc_call_mock.call_args_list) == 2
|
||||||
|
for i, txt in enumerate(["foo", "bar"]):
|
||||||
|
act_model_uri = request_stub.call_args_list[i].kwargs["model_uri"]
|
||||||
|
act_text = request_stub.call_args_list[i].kwargs["text"]
|
||||||
|
call_args = grpc_call_mock.call_args_list[i]
|
||||||
|
act_metadata = call_args.kwargs["metadata"]
|
||||||
|
assert "fldr" in act_model_uri
|
||||||
|
assert "query" not in act_model_uri
|
||||||
|
assert "doc" in act_model_uri
|
||||||
|
assert act_text == txt
|
||||||
|
assert act_metadata
|
||||||
|
assert len(act_metadata) > 0
|
||||||
|
if disable_logging.get("disable_request_logging"):
|
||||||
|
assert ("x-data-logging-enabled", "false") in call_args.kwargs[
|
||||||
|
"metadata"
|
||||||
|
]
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
import os
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_community.llms.yandex import YandexGPT
|
from langchain_community.llms.yandex import YandexGPT
|
||||||
@ -36,3 +40,53 @@ def test_yandexgpt_invalid_model_params() -> None:
|
|||||||
api_key="your_api_key", # type: ignore[arg-type]
|
api_key="your_api_key", # type: ignore[arg-type]
|
||||||
model_uri="",
|
model_uri="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")]
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"disable_logging",
|
||||||
|
[dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)],
|
||||||
|
)
|
||||||
|
@mock.patch.dict(os.environ, {}, clear=True)
|
||||||
|
def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None:
|
||||||
|
absent_yandex_module_stub = MagicMock()
|
||||||
|
grpc_mock = MagicMock()
|
||||||
|
with mock.patch.dict(
|
||||||
|
"sys.modules",
|
||||||
|
{
|
||||||
|
"yandex.cloud.ai.foundation_models.v1."
|
||||||
|
"text_common_pb2": absent_yandex_module_stub,
|
||||||
|
"yandex.cloud.ai.foundation_models.v1.text_generation."
|
||||||
|
"text_generation_service_pb2": absent_yandex_module_stub,
|
||||||
|
"yandex.cloud.ai.foundation_models.v1.text_generation."
|
||||||
|
"text_generation_service_pb2_grpc": absent_yandex_module_stub,
|
||||||
|
"grpc": grpc_mock,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
grpc_mock.RpcError = Exception
|
||||||
|
stub = absent_yandex_module_stub.TextGenerationServiceStub
|
||||||
|
request_stub = absent_yandex_module_stub.CompletionRequest
|
||||||
|
msg_constructor_stub = absent_yandex_module_stub.Message
|
||||||
|
args = {"folder_id": "fldr", **api_key_or_token, **disable_logging}
|
||||||
|
ygpt = YandexGPT(**args)
|
||||||
|
grpc_call_mock = stub.return_value.Completion
|
||||||
|
msg_mock = mock.Mock()
|
||||||
|
msg_mock.message.text = "cmpltn"
|
||||||
|
res_mock = mock.Mock()
|
||||||
|
res_mock.alternatives = [msg_mock]
|
||||||
|
grpc_call_mock.return_value = [res_mock]
|
||||||
|
act_emb = ygpt.invoke("nomatter")
|
||||||
|
assert act_emb == "cmpltn"
|
||||||
|
assert len(grpc_call_mock.call_args_list) == 1
|
||||||
|
once_called_args = grpc_call_mock.call_args_list[0]
|
||||||
|
act_model_uri = request_stub.call_args_list[0].kwargs["model_uri"]
|
||||||
|
act_text = msg_constructor_stub.call_args_list[0].kwargs["text"]
|
||||||
|
act_metadata = once_called_args.kwargs["metadata"]
|
||||||
|
assert "fldr" in act_model_uri
|
||||||
|
assert act_text == "nomatter"
|
||||||
|
assert act_metadata
|
||||||
|
assert len(act_metadata) > 0
|
||||||
|
if disable_logging.get("disable_request_logging"):
|
||||||
|
assert ("x-data-logging-enabled", "false") in act_metadata
|
||||||
|
Loading…
Reference in New Issue
Block a user