feat: mask api_key for konko (#14010)

for https://github.com/langchain-ai/langchain/issues/12165
This commit is contained in:
chyroc 2024-01-02 05:42:49 +08:00 committed by GitHub
parent 62d32bd214
commit a4ae4bc361
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 13 deletions

View File

@ -25,8 +25,8 @@ from langchain_core.language_models.chat_models import (
) )
from langchain_core.messages import AIMessageChunk, BaseMessage from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.adapters.openai import ( from langchain_community.adapters.openai import (
convert_dict_to_message, convert_dict_to_message,
@ -72,8 +72,8 @@ class ChatKonko(BaseChatModel):
"""What sampling temperature to use.""" """What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None openai_api_key: Optional[SecretStr] = None
konko_api_key: Optional[str] = None konko_api_key: Optional[SecretStr] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Konko completion API.""" """Timeout for requests to Konko completion API."""
max_retries: int = 6 max_retries: int = 6
@ -88,8 +88,8 @@ class ChatKonko(BaseChatModel):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["konko_api_key"] = get_from_dict_or_env( values["konko_api_key"] = convert_to_secret_str(
values, "konko_api_key", "KONKO_API_KEY" get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
) )
try: try:
import konko import konko
@ -128,8 +128,8 @@ class ChatKonko(BaseChatModel):
@staticmethod @staticmethod
def get_available_models( def get_available_models(
konko_api_key: Optional[str] = None, konko_api_key: Union[str, SecretStr, None] = None,
openai_api_key: Optional[str] = None, openai_api_key: Union[str, SecretStr, None] = None,
konko_api_base: str = DEFAULT_API_BASE, konko_api_base: str = DEFAULT_API_BASE,
) -> Set[str]: ) -> Set[str]:
"""Get available models from Konko API.""" """Get available models from Konko API."""
@ -137,28 +137,32 @@ class ChatKonko(BaseChatModel):
# Try to retrieve the OpenAI API key if it's not passed as an argument # Try to retrieve the OpenAI API key if it's not passed as an argument
if not openai_api_key: if not openai_api_key:
try: try:
openai_api_key = os.environ["OPENAI_API_KEY"] openai_api_key = convert_to_secret_str(os.environ["OPENAI_API_KEY"])
except KeyError: except KeyError:
pass # It's okay if it's not set, we just won't use it pass # It's okay if it's not set, we just won't use it
elif isinstance(openai_api_key, str):
openai_api_key = convert_to_secret_str(openai_api_key)
# Try to retrieve the Konko API key if it's not passed as an argument # Try to retrieve the Konko API key if it's not passed as an argument
if not konko_api_key: if not konko_api_key:
try: try:
konko_api_key = os.environ["KONKO_API_KEY"] konko_api_key = convert_to_secret_str(os.environ["KONKO_API_KEY"])
except KeyError: except KeyError:
raise ValueError( raise ValueError(
"Konko API key must be passed as keyword argument or " "Konko API key must be passed as keyword argument or "
"set in environment variable KONKO_API_KEY." "set in environment variable KONKO_API_KEY."
) )
elif isinstance(konko_api_key, str):
konko_api_key = convert_to_secret_str(konko_api_key)
models_url = f"{konko_api_base}/models" models_url = f"{konko_api_base}/models"
headers = { headers = {
"Authorization": f"Bearer {konko_api_key}", "Authorization": f"Bearer {konko_api_key.get_secret_value()}",
} }
if openai_api_key: if openai_api_key:
headers["X-OpenAI-Api-Key"] = openai_api_key headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()
models_response = requests.get(models_url, headers=headers) models_response = requests.get(models_url, headers=headers)

View File

@ -1,15 +1,57 @@
"""Evaluate ChatKonko Interface.""" """Evaluate ChatKonko Interface."""
from typing import Any from typing import Any, cast
import pytest import pytest
from langchain_core.callbacks import CallbackManager from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_community.chat_models.konko import ChatKonko from langchain_community.chat_models.konko import ChatKonko
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_konko_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("OPENAI_API_KEY", "test-openai-key")
monkeypatch.setenv("KONKO_API_KEY", "test-konko-key")
chat = ChatKonko()
print(chat.openai_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
print(chat.konko_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_konko_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")
print(chat.konko_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
print(chat.konko_secret_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secret_str() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")
assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key"
assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"
def test_konko_chat_test() -> None: def test_konko_chat_test() -> None:
"""Evaluate basic ChatKonko functionality.""" """Evaluate basic ChatKonko functionality."""
chat_instance = ChatKonko(max_tokens=10) chat_instance = ChatKonko(max_tokens=10)