diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py index ff88bd417f3..9fe24a50694 100644 --- a/libs/community/langchain_community/chat_models/konko.py +++ b/libs/community/langchain_community/chat_models/konko.py @@ -25,8 +25,8 @@ from langchain_core.language_models.chat_models import ( ) from langchain_core.messages import AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_community.adapters.openai import ( convert_dict_to_message, @@ -72,8 +72,8 @@ class ChatKonko(BaseChatModel): """What sampling temperature to use.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: Optional[str] = None - konko_api_key: Optional[str] = None + openai_api_key: Optional[SecretStr] = None + konko_api_key: Optional[SecretStr] = None request_timeout: Optional[Union[float, Tuple[float, float]]] = None """Timeout for requests to Konko completion API.""" max_retries: int = 6 @@ -88,8 +88,8 @@ class ChatKonko(BaseChatModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["konko_api_key"] = get_from_dict_or_env( - values, "konko_api_key", "KONKO_API_KEY" + values["konko_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY") ) try: import konko @@ -128,8 +128,8 @@ class ChatKonko(BaseChatModel): @staticmethod def get_available_models( - konko_api_key: Optional[str] = None, - openai_api_key: Optional[str] = None, + konko_api_key: Union[str, SecretStr, None] = None, + openai_api_key: Union[str, SecretStr, None] = None, konko_api_base: str = DEFAULT_API_BASE, ) -> Set[str]: """Get available models from Konko API.""" @@ -137,28 +137,32 @@ class ChatKonko(BaseChatModel): # Try to retrieve the OpenAI API key if it's not passed as an argument if not openai_api_key: try: - openai_api_key = os.environ["OPENAI_API_KEY"] + openai_api_key = convert_to_secret_str(os.environ["OPENAI_API_KEY"]) except KeyError: pass # It's okay if it's not set, we just won't use it + elif isinstance(openai_api_key, str): + openai_api_key = convert_to_secret_str(openai_api_key) # Try to retrieve the Konko API key if it's not passed as an argument if not konko_api_key: try: - konko_api_key = os.environ["KONKO_API_KEY"] + konko_api_key = convert_to_secret_str(os.environ["KONKO_API_KEY"]) except KeyError: raise ValueError( "Konko API key must be passed as keyword argument or " "set in environment variable KONKO_API_KEY." ) + elif isinstance(konko_api_key, str): + konko_api_key = convert_to_secret_str(konko_api_key) models_url = f"{konko_api_base}/models" headers = { - "Authorization": f"Bearer {konko_api_key}", + "Authorization": f"Bearer {konko_api_key.get_secret_value()}", } if openai_api_key: - headers["X-OpenAI-Api-Key"] = openai_api_key + headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value() models_response = requests.get(models_url, headers=headers) diff --git a/libs/community/tests/integration_tests/chat_models/test_konko.py b/libs/community/tests/integration_tests/chat_models/test_konko.py index 47554199348..b87e709d208 100644 --- a/libs/community/tests/integration_tests/chat_models/test_konko.py +++ b/libs/community/tests/integration_tests/chat_models/test_konko.py @@ -1,15 +1,57 @@ """Evaluate ChatKonko Interface.""" -from typing import Any +from typing import Any, cast import pytest from langchain_core.callbacks import CallbackManager from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain_community.chat_models.konko import ChatKonko from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +def test_konko_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("OPENAI_API_KEY", "test-openai-key") + monkeypatch.setenv("KONKO_API_KEY", "test-konko-key") + + chat = ChatKonko() + + print(chat.openai_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.konko_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_konko_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key") + + print(chat.konko_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.konko_secret_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secret_str() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key") + assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key" + assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key" + + def test_konko_chat_test() -> None: """Evaluate basic ChatKonko functionality.""" chat_instance = ChatKonko(max_tokens=10)