diff --git a/libs/community/langchain_community/chat_models/yandex.py b/libs/community/langchain_community/chat_models/yandex.py index d6b3eb66dc8..02aed41650c 100644 --- a/libs/community/langchain_community/chat_models/yandex.py +++ b/libs/community/langchain_community/chat_models/yandex.py @@ -170,7 +170,7 @@ def _make_request( messages=[Message(**message) for message in message_history], ) stub = TextGenerationServiceStub(channel) - res = stub.Completion(request, metadata=self._grpc_metadata) + res = stub.Completion(request, metadata=self.grpc_metadata) return list(res)[0].alternatives[0].message.text @@ -229,7 +229,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st messages=[Message(**message) for message in message_history], ) stub = TextGenerationAsyncServiceStub(channel) - operation = await stub.Completion(request, metadata=self._grpc_metadata) + operation = await stub.Completion(request, metadata=self.grpc_metadata) async with grpc.aio.secure_channel( operation_api_url, channel_credentials ) as operation_channel: @@ -239,7 +239,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st operation_request = GetOperationRequest(operation_id=operation.id) operation = await operation_stub.Get( operation_request, - metadata=self._grpc_metadata, + metadata=self.grpc_metadata, ) completion_response = CompletionResponse() diff --git a/libs/community/langchain_community/embeddings/yandex.py b/libs/community/langchain_community/embeddings/yandex.py index 7c662d49a06..dbddf245234 100644 --- a/libs/community/langchain_community/embeddings/yandex.py +++ b/libs/community/langchain_community/embeddings/yandex.py @@ -69,7 +69,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): disable_request_logging: bool = False """YandexGPT API logs all request data by default. If you provide personal data, confidential information, disable logging.""" - _grpc_metadata: Sequence + grpc_metadata: Sequence class Config: allow_population_by_field_name = True @@ -91,15 +91,15 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "": raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.") if values["iam_token"]: - values["_grpc_metadata"] = [ + values["grpc_metadata"] = [ ("authorization", f"Bearer {values['iam_token'].get_secret_value()}") ] if values["folder_id"]: - values["_grpc_metadata"].append(("x-folder-id", values["folder_id"])) + values["grpc_metadata"].append(("x-folder-id", values["folder_id"])) else: - values["_grpc_metadata"] = ( + values["grpc_metadata"] = [ ("authorization", f"Api-Key {values['api_key'].get_secret_value()}"), - ) + ] if not values.get("doc_model_uri"): if values["folder_id"] == "": @@ -114,7 +114,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}" ) if values["disable_request_logging"]: - values["_grpc_metadata"].append( + values["grpc_metadata"].append( ( "x-data-logging-enabled", "false", @@ -206,7 +206,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs): # typ for text in texts: request = TextEmbeddingRequest(model_uri=model_uri, text=text) stub = EmbeddingsServiceStub(channel) - res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] + res = stub.TextEmbedding(request, metadata=self.grpc_metadata) # type: ignore[attr-defined] result.append(list(res.embedding)) time.sleep(self.sleep_interval) diff --git a/libs/community/langchain_community/llms/yandex.py b/libs/community/langchain_community/llms/yandex.py index a29b0379441..b8761acd62b 100644 --- a/libs/community/langchain_community/llms/yandex.py +++ b/libs/community/langchain_community/llms/yandex.py @@ -57,7 +57,7 @@ class _BaseYandexGPT(Serializable): disable_request_logging: bool = False """YandexGPT API logs all request data by default. If you provide personal data, confidential information, disable logging.""" - _grpc_metadata: Optional[Sequence] = None + grpc_metadata: Optional[Sequence] = None @property def _llm_type(self) -> str: @@ -92,15 +92,15 @@ class _BaseYandexGPT(Serializable): raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.") if values["iam_token"]: - values["_grpc_metadata"] = [ + values["grpc_metadata"] = [ ("authorization", f"Bearer {values['iam_token'].get_secret_value()}") ] if values["folder_id"]: - values["_grpc_metadata"].append(("x-folder-id", values["folder_id"])) + values["grpc_metadata"].append(("x-folder-id", values["folder_id"])) else: - values["_grpc_metadata"] = ( + values["grpc_metadata"] = [ ("authorization", f"Api-Key {values['api_key'].get_secret_value()}"), - ) + ] if values["model_uri"] == "" and values["folder_id"] == "": raise ValueError("Either 'model_uri' or 'folder_id' must be provided.") if not values["model_uri"]: @@ -108,7 +108,7 @@ class _BaseYandexGPT(Serializable): f"gpt://{values['folder_id']}/{values['model_name']}/{values['model_version']}" ) if values["disable_request_logging"]: - values["_grpc_metadata"].append( + values["grpc_metadata"].append( ( "x-data-logging-enabled", "false", @@ -235,7 +235,7 @@ def _make_request( messages=[Message(role="user", text=prompt)], ) stub = TextGenerationServiceStub(channel) - res = stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] + res = stub.Completion(request, metadata=self.grpc_metadata) # type: ignore[attr-defined] return list(res)[0].alternatives[0].message.text @@ -291,7 +291,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str: messages=[Message(role="user", text=prompt)], ) stub = TextGenerationAsyncServiceStub(channel) - operation = await stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] + operation = await stub.Completion(request, metadata=self.grpc_metadata) # type: ignore[attr-defined] async with grpc.aio.secure_channel( operation_api_url, channel_credentials ) as operation_channel: @@ -301,7 +301,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str: operation_request = GetOperationRequest(operation_id=operation.id) operation = await operation_stub.Get( operation_request, - metadata=self._grpc_metadata, # type: ignore[attr-defined] + metadata=self.grpc_metadata, # type: ignore[attr-defined] ) completion_response = CompletionResponse() diff --git a/libs/community/tests/integration_tests/embeddings/test_yandex.py b/libs/community/tests/integration_tests/embeddings/test_yandex.py new file mode 100644 index 00000000000..ff7f5e36bf6 --- /dev/null +++ b/libs/community/tests/integration_tests/embeddings/test_yandex.py @@ -0,0 +1,31 @@ +import pytest + +from langchain_community.embeddings.yandex import YandexGPTEmbeddings + + +@pytest.mark.parametrize( + "constructor_args", + [ + dict(), + dict(disable_request_logging=True), + ], +) +# @pytest.mark.scheduled - idk what it means +# requires YC_* env and active service +def test_yandex_embedding(constructor_args: dict) -> None: + documents = ["exactly same", "exactly same", "different"] + embedding = YandexGPTEmbeddings(**constructor_args) + doc_outputs = embedding.embed_documents(documents) + assert len(doc_outputs) == 3 + for i in range(3): + assert len(doc_outputs[i]) >= 256 # there are many dims + assert len(doc_outputs[0]) == len(doc_outputs[i]) # dims are te same + assert doc_outputs[0] == doc_outputs[1] # same input, same embeddings + assert doc_outputs[2] != doc_outputs[1] # different input, different embeddings + + qry_output = embedding.embed_query(documents[0]) + assert len(qry_output) >= 256 + assert len(doc_outputs[0]) == len( + qry_output + ) # query and doc models have same dimensions + assert doc_outputs[0] != qry_output # query and doc models are different diff --git a/libs/community/tests/unit_tests/chat_models/test_yandex.py b/libs/community/tests/unit_tests/chat_models/test_yandex.py new file mode 100644 index 00000000000..0dd741909ed --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_yandex.py @@ -0,0 +1,92 @@ +import os +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from langchain_community.chat_models.yandex import ChatYandexGPT + + +def test_yandexgpt_initialization() -> None: + llm = ChatYandexGPT( + iam_token="your_iam_token", # type: ignore[arg-type] + api_key="your_api_key", # type: ignore[arg-type] + folder_id="your_folder_id", + ) + assert llm.model_name == "yandexgpt-lite" + assert llm.model_uri.startswith("gpt://your_folder_id/yandexgpt-lite/") + + +def test_yandexgpt_model_params() -> None: + llm = ChatYandexGPT( + model_name="custom-model", + model_version="v1", + iam_token="your_iam_token", # type: ignore[arg-type] + api_key="your_api_key", # type: ignore[arg-type] + folder_id="your_folder_id", + ) + assert llm.model_name == "custom-model" + assert llm.model_version == "v1" + assert llm.iam_token.get_secret_value() == "your_iam_token" + assert llm.model_uri == "gpt://your_folder_id/custom-model/v1" + + +def test_yandexgpt_invalid_model_params() -> None: + with pytest.raises(ValueError): + ChatYandexGPT(model_uri="", iam_token="your_iam_token") # type: ignore[arg-type] + with pytest.raises(ValueError): + ChatYandexGPT( + iam_token="", # type: ignore[arg-type] + api_key="your_api_key", # type: ignore[arg-type] + model_uri="", + ) + + +@pytest.mark.parametrize( + "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")] +) +@pytest.mark.parametrize( + "disable_logging", + [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None: + absent_yandex_module_stub = MagicMock() + grpc_mock = MagicMock() + with mock.patch.dict( + "sys.modules", + { + "yandex.cloud.ai.foundation_models.v1." + "text_common_pb2": absent_yandex_module_stub, + "yandex.cloud.ai.foundation_models.v1.text_generation." + "text_generation_service_pb2": absent_yandex_module_stub, + "yandex.cloud.ai.foundation_models.v1.text_generation." + "text_generation_service_pb2_grpc": absent_yandex_module_stub, + "grpc": grpc_mock, + }, + ): + grpc_mock.RpcError = Exception + stub = absent_yandex_module_stub.TextGenerationServiceStub + request_stub = absent_yandex_module_stub.CompletionRequest + msg_constructor_stub = absent_yandex_module_stub.Message + args = {"folder_id": "fldr", **api_key_or_token, **disable_logging} + ygpt = ChatYandexGPT(**args) + grpc_call_mock = stub.return_value.Completion + msg_mock = mock.Mock() + msg_mock.message.text = "cmpltn" + res_mock = mock.Mock() + res_mock.alternatives = [msg_mock] + grpc_call_mock.return_value = [res_mock] + act_emb = ygpt.invoke("nomatter") + assert act_emb.content == "cmpltn" + assert len(grpc_call_mock.call_args_list) == 1 + once_called_args = grpc_call_mock.call_args_list[0] + act_model_uri = request_stub.call_args_list[0].kwargs["model_uri"] + act_text = msg_constructor_stub.call_args_list[0].kwargs["text"] + act_metadata = once_called_args.kwargs["metadata"] + assert "fldr" in act_model_uri + assert act_text == "nomatter" + assert act_metadata + assert len(act_metadata) > 0 + if disable_logging.get("disable_request_logging"): + assert ("x-data-logging-enabled", "false") in act_metadata diff --git a/libs/community/tests/unit_tests/embeddings/test_yandex.py b/libs/community/tests/unit_tests/embeddings/test_yandex.py index c681af05a28..76ab574f721 100644 --- a/libs/community/tests/unit_tests/embeddings/test_yandex.py +++ b/libs/community/tests/unit_tests/embeddings/test_yandex.py @@ -1,10 +1,21 @@ import os +from unittest import mock +from unittest.mock import MagicMock + +import pytest from langchain_community.embeddings import YandexGPTEmbeddings +YANDEX_MODULE_NAME2 = ( + "yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2_grpc" +) +YANDEX_MODULE_NAME = ( + "yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2" +) + +@mock.patch.dict(os.environ, {"YC_API_KEY": "foo"}, clear=True) def test_init() -> None: - os.environ["YC_API_KEY"] = "foo" models = [ YandexGPTEmbeddings(folder_id="bar"), # type: ignore[call-arg] YandexGPTEmbeddings( # type: ignore[call-arg] @@ -22,3 +33,92 @@ def test_init() -> None: assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest" assert embeddings.model_name == "text-search-query" assert embeddings.doc_model_name == "text-search-doc" + + +@pytest.mark.parametrize( + "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")] +) +@pytest.mark.parametrize( + "disable_logging", + [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_query_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None: + absent_yandex_module_stub = MagicMock() + with mock.patch.dict( + "sys.modules", + { + YANDEX_MODULE_NAME: absent_yandex_module_stub, + YANDEX_MODULE_NAME2: absent_yandex_module_stub, + "grpc": MagicMock(), + }, + ): + stub = absent_yandex_module_stub.EmbeddingsServiceStub + request_stub = absent_yandex_module_stub.TextEmbeddingRequest + args = {"folder_id": "fldr", **api_key_or_token, **disable_logging} + ygpt = YandexGPTEmbeddings(**args) + grpc_call_mock = stub.return_value.TextEmbedding + grpc_call_mock.return_value.embedding = [1, 2, 3] + act_emb = ygpt.embed_query("nomatter") + assert act_emb == [1, 2, 3] + assert len(grpc_call_mock.call_args_list) == 1 + once_called_args = grpc_call_mock.call_args_list[0] + act_model_uri = request_stub.call_args_list[0].kwargs["model_uri"] + assert "fldr" in act_model_uri + assert "query" in act_model_uri + assert "doc" not in act_model_uri + act_text = request_stub.call_args_list[0].kwargs["text"] + assert act_text == "nomatter" + act_metadata = once_called_args.kwargs["metadata"] + assert act_metadata + assert len(act_metadata) > 0 + if disable_logging.get("disable_request_logging"): + assert ("x-data-logging-enabled", "false") in act_metadata + + +@pytest.mark.parametrize( + "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")] +) +@pytest.mark.parametrize( + "disable_logging", + [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_doc_embedding_call(api_key_or_token: dict, disable_logging: dict) -> None: + absent_yandex_module_stub = MagicMock() + with mock.patch.dict( + "sys.modules", + { + YANDEX_MODULE_NAME: absent_yandex_module_stub, + YANDEX_MODULE_NAME2: absent_yandex_module_stub, + "grpc": MagicMock(), + }, + ): + stub = absent_yandex_module_stub.EmbeddingsServiceStub + request_stub = absent_yandex_module_stub.TextEmbeddingRequest + args = {"folder_id": "fldr", **api_key_or_token, **disable_logging} + ygpt = YandexGPTEmbeddings(**args) + grpc_call_mock = stub.return_value.TextEmbedding + foo_emb = mock.Mock() + foo_emb.embedding = [1, 2, 3] + bar_emb = mock.Mock() + bar_emb.embedding = [4, 5, 6] + grpc_call_mock.side_effect = [foo_emb, bar_emb] + act_emb = ygpt.embed_documents(["foo", "bar"]) + assert act_emb == [[1, 2, 3], [4, 5, 6]] + assert len(grpc_call_mock.call_args_list) == 2 + for i, txt in enumerate(["foo", "bar"]): + act_model_uri = request_stub.call_args_list[i].kwargs["model_uri"] + act_text = request_stub.call_args_list[i].kwargs["text"] + call_args = grpc_call_mock.call_args_list[i] + act_metadata = call_args.kwargs["metadata"] + assert "fldr" in act_model_uri + assert "query" not in act_model_uri + assert "doc" in act_model_uri + assert act_text == txt + assert act_metadata + assert len(act_metadata) > 0 + if disable_logging.get("disable_request_logging"): + assert ("x-data-logging-enabled", "false") in call_args.kwargs[ + "metadata" + ] diff --git a/libs/community/tests/unit_tests/llms/test_yandex.py b/libs/community/tests/unit_tests/llms/test_yandex.py index aadbc6ad845..a694b2ab98a 100644 --- a/libs/community/tests/unit_tests/llms/test_yandex.py +++ b/libs/community/tests/unit_tests/llms/test_yandex.py @@ -1,3 +1,7 @@ +import os +from unittest import mock +from unittest.mock import MagicMock + import pytest from langchain_community.llms.yandex import YandexGPT @@ -36,3 +40,53 @@ def test_yandexgpt_invalid_model_params() -> None: api_key="your_api_key", # type: ignore[arg-type] model_uri="", ) + + +@pytest.mark.parametrize( + "api_key_or_token", [dict(api_key="bogus"), dict(iam_token="bogus")] +) +@pytest.mark.parametrize( + "disable_logging", + [dict(), dict(disable_request_logging=True), dict(disable_request_logging=False)], +) +@mock.patch.dict(os.environ, {}, clear=True) +def test_completion_call(api_key_or_token: dict, disable_logging: dict) -> None: + absent_yandex_module_stub = MagicMock() + grpc_mock = MagicMock() + with mock.patch.dict( + "sys.modules", + { + "yandex.cloud.ai.foundation_models.v1." + "text_common_pb2": absent_yandex_module_stub, + "yandex.cloud.ai.foundation_models.v1.text_generation." + "text_generation_service_pb2": absent_yandex_module_stub, + "yandex.cloud.ai.foundation_models.v1.text_generation." + "text_generation_service_pb2_grpc": absent_yandex_module_stub, + "grpc": grpc_mock, + }, + ): + grpc_mock.RpcError = Exception + stub = absent_yandex_module_stub.TextGenerationServiceStub + request_stub = absent_yandex_module_stub.CompletionRequest + msg_constructor_stub = absent_yandex_module_stub.Message + args = {"folder_id": "fldr", **api_key_or_token, **disable_logging} + ygpt = YandexGPT(**args) + grpc_call_mock = stub.return_value.Completion + msg_mock = mock.Mock() + msg_mock.message.text = "cmpltn" + res_mock = mock.Mock() + res_mock.alternatives = [msg_mock] + grpc_call_mock.return_value = [res_mock] + act_emb = ygpt.invoke("nomatter") + assert act_emb == "cmpltn" + assert len(grpc_call_mock.call_args_list) == 1 + once_called_args = grpc_call_mock.call_args_list[0] + act_model_uri = request_stub.call_args_list[0].kwargs["model_uri"] + act_text = msg_constructor_stub.call_args_list[0].kwargs["text"] + act_metadata = once_called_args.kwargs["metadata"] + assert "fldr" in act_model_uri + assert act_text == "nomatter" + assert act_metadata + assert len(act_metadata) > 0 + if disable_logging.get("disable_request_logging"): + assert ("x-data-logging-enabled", "false") in act_metadata