From d266b3ea4a5dc85fe895b93090d89aa311f8c48e Mon Sep 17 00:00:00 2001 From: fyasla <53271240+fyasla@users.noreply.github.com> Date: Fri, 10 Nov 2023 20:05:19 +0100 Subject: [PATCH] issue #12165 mask API key in chat_models/azureml_endpoint module (#12836) - **Description:** `AzureMLChatOnlineEndpoint` object from langchain/chat_models/azureml_endpoint.py safe to print without having any secrets included in raw format in the string representation. - **Issue:** #12165, - **Tag maintainer:** @eyurtsev --------- Co-authored-by: Faysal Bougamale Co-authored-by: Bagatur --- .../langchain/chat_models/azureml_endpoint.py | 14 ++-- .../chat_models/test_azureml_endpoint.py | 65 +++++++++++++++++++ 2 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py diff --git a/libs/langchain/langchain/chat_models/azureml_endpoint.py b/libs/langchain/langchain/chat_models/azureml_endpoint.py index 53bdc849252..8efa957ad0f 100644 --- a/libs/langchain/langchain/chat_models/azureml_endpoint.py +++ b/libs/langchain/langchain/chat_models/azureml_endpoint.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase -from langchain.pydantic_v1 import validator +from langchain.pydantic_v1 import SecretStr, validator from langchain.schema.messages import ( AIMessage, BaseMessage, @@ -12,7 +12,7 @@ from langchain.schema.messages import ( HumanMessage, SystemMessage, ) -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class LlamaContentFormatter(ContentFormatterBase): @@ -94,7 +94,7 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel): """URL of pre-existing Endpoint. Should be passed to constructor or specified as env var `AZUREML_ENDPOINT_URL`.""" - endpoint_api_key: str = "" + endpoint_api_key: SecretStr = convert_to_secret_str("") """Authentication Key for Endpoint. Should be passed to constructor or specified as env var `AZUREML_ENDPOINT_API_KEY`.""" @@ -112,13 +112,15 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel): @classmethod def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient: """Validate that api key and python package exist in environment.""" - endpoint_key = get_from_dict_or_env( - values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY" + values["endpoint_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY") ) endpoint_url = get_from_dict_or_env( values, "endpoint_url", "AZUREML_ENDPOINT_URL" ) - http_client = AzureMLEndpointClient(endpoint_url, endpoint_key) + http_client = AzureMLEndpointClient( + endpoint_url, values["endpoint_api_key"].get_secret_value() + ) return http_client @property diff --git a/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py b/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py new file mode 100644 index 00000000000..de1055fc3d5 --- /dev/null +++ b/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py @@ -0,0 +1,65 @@ +"""Test AzureML chat endpoint.""" + +import os + +import pytest +from pytest import CaptureFixture, FixtureRequest + +from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint +from langchain.pydantic_v1 import SecretStr + + +@pytest.fixture(scope="class") +def api_passed_via_environment_fixture() -> AzureMLChatOnlineEndpoint: + """Fixture to create an AzureMLChatOnlineEndpoint instance + with API key passed from environment""" + os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key" + azure_chat = AzureMLChatOnlineEndpoint( + endpoint_url="https://..inference.ml.azure.com/score" + ) + del os.environ["AZUREML_ENDPOINT_API_KEY"] + return azure_chat + + +@pytest.fixture(scope="class") +def api_passed_via_constructor_fixture() -> AzureMLChatOnlineEndpoint: + """Fixture to create an AzureMLChatOnlineEndpoint instance + with API key passed from constructor""" + azure_chat = AzureMLChatOnlineEndpoint( + endpoint_url="https://..inference.ml.azure.com/score", + endpoint_api_key="my-api-key", + ) + return azure_chat + + +@pytest.mark.parametrize( + "fixture_name", + ["api_passed_via_constructor_fixture", "api_passed_via_environment_fixture"], +) +class TestAzureMLChatOnlineEndpoint: + def test_api_key_is_secret_string( + self, fixture_name: str, request: FixtureRequest + ) -> None: + """Test that the API key is a SecretStr instance""" + azure_chat = request.getfixturevalue(fixture_name) + assert isinstance(azure_chat.endpoint_api_key, SecretStr) + + def test_api_key_masked( + self, fixture_name: str, request: FixtureRequest, capsys: CaptureFixture + ) -> None: + """Test that the API key is masked""" + azure_chat = request.getfixturevalue(fixture_name) + print(azure_chat.endpoint_api_key, end="") + captured = capsys.readouterr() + assert ( + (str(azure_chat.endpoint_api_key) == "**********") + and (repr(azure_chat.endpoint_api_key) == "SecretStr('**********')") + and (captured.out == "**********") + ) + + def test_api_key_is_readable( + self, fixture_name: str, request: FixtureRequest + ) -> None: + """Test that the real secret value of the API key can be read""" + azure_chat = request.getfixturevalue(fixture_name) + assert azure_chat.endpoint_api_key.get_secret_value() == "my-api-key"