From b6952d41e5bf57a24c6d9865322dd4f3d876d5cc Mon Sep 17 00:00:00 2001 From: chyroc Date: Tue, 2 Jan 2024 07:20:26 +0800 Subject: [PATCH] Refactor: use SecretStr for GPTRouter chat-model (#15101) --- .../chat_models/gpt_router.py | 23 ++++++++++--------- .../chat_models/test_gpt_router.py | 23 +++++++++++++++++++ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/libs/community/langchain_community/chat_models/gpt_router.py b/libs/community/langchain_community/chat_models/gpt_router.py index ac91200ed4e..498d8542c8d 100644 --- a/libs/community/langchain_community/chat_models/gpt_router.py +++ b/libs/community/langchain_community/chat_models/gpt_router.py @@ -29,8 +29,8 @@ from langchain_core.language_models.chat_models import ( from langchain_core.language_models.llms import create_base_retry_decorator from langchain_core.messages import AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_community.adapters.openai import ( convert_dict_to_message, @@ -150,7 +150,7 @@ class GPTRouter(BaseChatModel): models_priority_list: List[GPTRouterModel] = Field(min_items=1) gpt_router_api_base: str = Field(default=None) """WriteSonic GPTRouter custom endpoint""" - gpt_router_api_key: Optional[str] = None + gpt_router_api_key: Optional[SecretStr] = None """WriteSonic GPTRouter API Key""" temperature: float = 0.7 """What sampling temperature to use.""" @@ -173,10 +173,12 @@ class GPTRouter(BaseChatModel): DEFAULT_API_BASE_URL, ) - values["gpt_router_api_key"] = get_from_dict_or_env( - values, - "gpt_router_api_key", - "GPT_ROUTER_API_KEY", + values["gpt_router_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "gpt_router_api_key", + "GPT_ROUTER_API_KEY", + ) ) try: @@ -189,7 +191,8 @@ class GPTRouter(BaseChatModel): ) gpt_router_client = GPTRouterClient( - values["gpt_router_api_base"], values["gpt_router_api_key"] + values["gpt_router_api_base"], + values["gpt_router_api_key"].get_secret_value(), ) values["client"] = gpt_router_client @@ -197,9 +200,7 @@ class GPTRouter(BaseChatModel): @property def lc_secrets(self) -> Dict[str, str]: - return { - "gpt_router_api_key": "GPT_ROUTER_API_KEY", - } + return {"gpt_router_api_key": "GPT_ROUTER_API_KEY"} @property def lc_serializable(self) -> bool: diff --git a/libs/community/tests/integration_tests/chat_models/test_gpt_router.py b/libs/community/tests/integration_tests/chat_models/test_gpt_router.py index fb387154f46..785ad6ac0a1 100644 --- a/libs/community/tests/integration_tests/chat_models/test_gpt_router.py +++ b/libs/community/tests/integration_tests/chat_models/test_gpt_router.py @@ -7,11 +7,34 @@ from langchain_core.callbacks import ( ) from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, LLMResult +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture from langchain_community.chat_models.gpt_router import GPTRouter, GPTRouterModel from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +def test_api_key_is_string() -> None: + gpt_router = GPTRouter( + gpt_router_api_base="https://example.com", + gpt_router_api_key="secret-api-key", + ) + assert isinstance(gpt_router.gpt_router_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + gpt_router = GPTRouter( + gpt_router_api_base="https://example.com", + gpt_router_api_key="secret-api-key", + ) + print(gpt_router.gpt_router_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + def test_gpt_router_call() -> None: """Test valid call to GPTRouter.""" anthropic_claude = GPTRouterModel(