mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 02:29:17 +00:00
- **Description:** `AzureMLChatOnlineEndpoint` object from langchain/chat_models/azureml_endpoint.py safe to print without having any secrets included in raw format in the string representation. - **Issue:** #12165, - **Tag maintainer:** @eyurtsev --------- Co-authored-by: Faysal Bougamale <faysal.bougamale@horiba.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
52f34de9b7
commit
d266b3ea4a
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import SimpleChatModel
|
||||
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
|
||||
from langchain.pydantic_v1 import validator
|
||||
from langchain.pydantic_v1 import SecretStr, validator
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
@ -12,7 +12,7 @@ from langchain.schema.messages import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
|
||||
class LlamaContentFormatter(ContentFormatterBase):
|
||||
@ -94,7 +94,7 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
|
||||
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
|
||||
env var `AZUREML_ENDPOINT_URL`."""
|
||||
|
||||
endpoint_api_key: str = ""
|
||||
endpoint_api_key: SecretStr = convert_to_secret_str("")
|
||||
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
|
||||
env var `AZUREML_ENDPOINT_API_KEY`."""
|
||||
|
||||
@ -112,13 +112,15 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
|
||||
@classmethod
|
||||
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
|
||||
"""Validate that api key and python package exist in environment."""
|
||||
endpoint_key = get_from_dict_or_env(
|
||||
values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
|
||||
values["endpoint_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
|
||||
)
|
||||
endpoint_url = get_from_dict_or_env(
|
||||
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
|
||||
)
|
||||
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key)
|
||||
http_client = AzureMLEndpointClient(
|
||||
endpoint_url, values["endpoint_api_key"].get_secret_value()
|
||||
)
|
||||
return http_client
|
||||
|
||||
@property
|
||||
|
@ -0,0 +1,65 @@
|
||||
"""Test AzureML chat endpoint."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest import CaptureFixture, FixtureRequest
|
||||
|
||||
from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint
|
||||
from langchain.pydantic_v1 import SecretStr
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def api_passed_via_environment_fixture() -> AzureMLChatOnlineEndpoint:
|
||||
"""Fixture to create an AzureMLChatOnlineEndpoint instance
|
||||
with API key passed from environment"""
|
||||
os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key"
|
||||
azure_chat = AzureMLChatOnlineEndpoint(
|
||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score"
|
||||
)
|
||||
del os.environ["AZUREML_ENDPOINT_API_KEY"]
|
||||
return azure_chat
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def api_passed_via_constructor_fixture() -> AzureMLChatOnlineEndpoint:
|
||||
"""Fixture to create an AzureMLChatOnlineEndpoint instance
|
||||
with API key passed from constructor"""
|
||||
azure_chat = AzureMLChatOnlineEndpoint(
|
||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
||||
endpoint_api_key="my-api-key",
|
||||
)
|
||||
return azure_chat
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fixture_name",
|
||||
["api_passed_via_constructor_fixture", "api_passed_via_environment_fixture"],
|
||||
)
|
||||
class TestAzureMLChatOnlineEndpoint:
|
||||
def test_api_key_is_secret_string(
|
||||
self, fixture_name: str, request: FixtureRequest
|
||||
) -> None:
|
||||
"""Test that the API key is a SecretStr instance"""
|
||||
azure_chat = request.getfixturevalue(fixture_name)
|
||||
assert isinstance(azure_chat.endpoint_api_key, SecretStr)
|
||||
|
||||
def test_api_key_masked(
|
||||
self, fixture_name: str, request: FixtureRequest, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test that the API key is masked"""
|
||||
azure_chat = request.getfixturevalue(fixture_name)
|
||||
print(azure_chat.endpoint_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
assert (
|
||||
(str(azure_chat.endpoint_api_key) == "**********")
|
||||
and (repr(azure_chat.endpoint_api_key) == "SecretStr('**********')")
|
||||
and (captured.out == "**********")
|
||||
)
|
||||
|
||||
def test_api_key_is_readable(
|
||||
self, fixture_name: str, request: FixtureRequest
|
||||
) -> None:
|
||||
"""Test that the real secret value of the API key can be read"""
|
||||
azure_chat = request.getfixturevalue(fixture_name)
|
||||
assert azure_chat.endpoint_api_key.get_secret_value() == "my-api-key"
|
Loading…
Reference in New Issue
Block a user