From f97ab84c6b32e8d9ae9e9ab48e708170bf9cb6c9 Mon Sep 17 00:00:00 2001 From: chyroc Date: Wed, 29 Nov 2023 10:24:50 +0800 Subject: [PATCH] Merge pull request #13907 * feat: mask api_key for jina --- .../langchain/chat_models/jinachat.py | 13 ++++--- .../chat_models/test_jinachat.py | 37 +++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/chat_models/jinachat.py b/libs/langchain/langchain/chat_models/jinachat.py index 4ad241ea35d..bdb77e25b20 100644 --- a/libs/langchain/langchain/chat_models/jinachat.py +++ b/libs/langchain/langchain/chat_models/jinachat.py @@ -30,8 +30,8 @@ from langchain_core.messages import ( SystemMessageChunk, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_pydantic_field_names +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_pydantic_field_names from tenacity import ( before_sleep_log, retry, @@ -172,7 +172,7 @@ class JinaChat(BaseChatModel): """What sampling temperature to use.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - jinachat_api_key: Optional[str] = None + jinachat_api_key: Optional[SecretStr] = None """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" request_timeout: Optional[Union[float, Tuple[float, float]]] = None @@ -218,8 +218,8 @@ class JinaChat(BaseChatModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["jinachat_api_key"] = get_from_dict_or_env( - values, "jinachat_api_key", "JINACHAT_API_KEY" + values["jinachat_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "jinachat_api_key", "JINACHAT_API_KEY") ) try: import openai @@ -395,7 +395,8 @@ class JinaChat(BaseChatModel): def _invocation_params(self) -> Mapping[str, Any]: """Get the parameters used to invoke the model.""" jinachat_creds: Dict[str, Any] = { - "api_key": self.jinachat_api_key, + "api_key": self.jinachat_api_key + and self.jinachat_api_key.get_secret_value(), "api_base": "https://api.chat.jina.ai/v1", "model": "jinachat", } diff --git a/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py b/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py index 8b7cdc129c1..85ec7c9d45f 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_jinachat.py @@ -1,15 +1,52 @@ """Test JinaChat wrapper.""" +from typing import cast import pytest from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatGeneration, LLMResult +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain.callbacks.manager import CallbackManager from langchain.chat_models.jinachat import JinaChat from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +def test_jinachat_api_key_is_secret_string() -> None: + llm = JinaChat(jinachat_api_key="secret-api-key") + assert isinstance(llm.jinachat_api_key, SecretStr) + + +def test_jinachat_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("JINACHAT_API_KEY", "secret-api-key") + llm = JinaChat() + print(llm.jinachat_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_jinachat_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = JinaChat(jinachat_api_key="secret-api-key") + print(llm.jinachat_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = JinaChat(jinachat_api_key="secret-api-key") + assert cast(SecretStr, llm.jinachat_api_key).get_secret_value() == "secret-api-key" + + def test_jinachat() -> None: """Test JinaChat wrapper.""" chat = JinaChat(max_tokens=10)