mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
- **Description:** `AzureMLChatOnlineEndpoint` object from langchain/chat_models/azureml_endpoint.py safe to print without having any secrets included in raw format in the string representation. - **Issue:** #12165, - **Tag maintainer:** @eyurtsev --------- Co-authored-by: Faysal Bougamale <faysal.bougamale@horiba.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
52f34de9b7
commit
d266b3ea4a
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast
|
|||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.chat_models.base import SimpleChatModel
|
from langchain.chat_models.base import SimpleChatModel
|
||||||
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
|
from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase
|
||||||
from langchain.pydantic_v1 import validator
|
from langchain.pydantic_v1 import SecretStr, validator
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -12,7 +12,7 @@ from langchain.schema.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
class LlamaContentFormatter(ContentFormatterBase):
|
class LlamaContentFormatter(ContentFormatterBase):
|
||||||
@ -94,7 +94,7 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
|
|||||||
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
|
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
|
||||||
env var `AZUREML_ENDPOINT_URL`."""
|
env var `AZUREML_ENDPOINT_URL`."""
|
||||||
|
|
||||||
endpoint_api_key: str = ""
|
endpoint_api_key: SecretStr = convert_to_secret_str("")
|
||||||
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
|
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
|
||||||
env var `AZUREML_ENDPOINT_API_KEY`."""
|
env var `AZUREML_ENDPOINT_API_KEY`."""
|
||||||
|
|
||||||
@ -112,13 +112,15 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
|
def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
|
||||||
"""Validate that api key and python package exist in environment."""
|
"""Validate that api key and python package exist in environment."""
|
||||||
endpoint_key = get_from_dict_or_env(
|
values["endpoint_api_key"] = convert_to_secret_str(
|
||||||
values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
|
get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
|
||||||
)
|
)
|
||||||
endpoint_url = get_from_dict_or_env(
|
endpoint_url = get_from_dict_or_env(
|
||||||
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
|
values, "endpoint_url", "AZUREML_ENDPOINT_URL"
|
||||||
)
|
)
|
||||||
http_client = AzureMLEndpointClient(endpoint_url, endpoint_key)
|
http_client = AzureMLEndpointClient(
|
||||||
|
endpoint_url, values["endpoint_api_key"].get_secret_value()
|
||||||
|
)
|
||||||
return http_client
|
return http_client
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
"""Test AzureML chat endpoint."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest import CaptureFixture, FixtureRequest
|
||||||
|
|
||||||
|
from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint
|
||||||
|
from langchain.pydantic_v1 import SecretStr
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def api_passed_via_environment_fixture() -> AzureMLChatOnlineEndpoint:
|
||||||
|
"""Fixture to create an AzureMLChatOnlineEndpoint instance
|
||||||
|
with API key passed from environment"""
|
||||||
|
os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key"
|
||||||
|
azure_chat = AzureMLChatOnlineEndpoint(
|
||||||
|
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score"
|
||||||
|
)
|
||||||
|
del os.environ["AZUREML_ENDPOINT_API_KEY"]
|
||||||
|
return azure_chat
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def api_passed_via_constructor_fixture() -> AzureMLChatOnlineEndpoint:
|
||||||
|
"""Fixture to create an AzureMLChatOnlineEndpoint instance
|
||||||
|
with API key passed from constructor"""
|
||||||
|
azure_chat = AzureMLChatOnlineEndpoint(
|
||||||
|
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
||||||
|
endpoint_api_key="my-api-key",
|
||||||
|
)
|
||||||
|
return azure_chat
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"fixture_name",
|
||||||
|
["api_passed_via_constructor_fixture", "api_passed_via_environment_fixture"],
|
||||||
|
)
|
||||||
|
class TestAzureMLChatOnlineEndpoint:
|
||||||
|
def test_api_key_is_secret_string(
|
||||||
|
self, fixture_name: str, request: FixtureRequest
|
||||||
|
) -> None:
|
||||||
|
"""Test that the API key is a SecretStr instance"""
|
||||||
|
azure_chat = request.getfixturevalue(fixture_name)
|
||||||
|
assert isinstance(azure_chat.endpoint_api_key, SecretStr)
|
||||||
|
|
||||||
|
def test_api_key_masked(
|
||||||
|
self, fixture_name: str, request: FixtureRequest, capsys: CaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test that the API key is masked"""
|
||||||
|
azure_chat = request.getfixturevalue(fixture_name)
|
||||||
|
print(azure_chat.endpoint_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert (
|
||||||
|
(str(azure_chat.endpoint_api_key) == "**********")
|
||||||
|
and (repr(azure_chat.endpoint_api_key) == "SecretStr('**********')")
|
||||||
|
and (captured.out == "**********")
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_api_key_is_readable(
|
||||||
|
self, fixture_name: str, request: FixtureRequest
|
||||||
|
) -> None:
|
||||||
|
"""Test that the real secret value of the API key can be read"""
|
||||||
|
azure_chat = request.getfixturevalue(fixture_name)
|
||||||
|
assert azure_chat.endpoint_api_key.get_secret_value() == "my-api-key"
|
Loading…
Reference in New Issue
Block a user