From d63ceb65b3fd61065401c7bf25bbcb7df1419c1b Mon Sep 17 00:00:00 2001 From: chyroc Date: Wed, 27 Dec 2023 04:59:51 +0800 Subject: [PATCH] Refactor: use SecretStr for StochasticAI llms (#15118) --- .../langchain_community/llms/stochasticai.py | 14 +++++++------- .../unit_tests/llms/test_stochasticai.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 libs/community/tests/unit_tests/llms/test_stochasticai.py diff --git a/libs/community/langchain_community/llms/stochasticai.py b/libs/community/langchain_community/llms/stochasticai.py index 0b3637e9aff..d645a019da7 100644 --- a/libs/community/langchain_community/llms/stochasticai.py +++ b/libs/community/langchain_community/llms/stochasticai.py @@ -5,8 +5,8 @@ from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_community.llms.utils import enforce_stop_tokens @@ -33,7 +33,7 @@ class StochasticAI(LLM): """Holds any model parameters valid for `create` call not explicitly specified.""" - stochasticai_api_key: Optional[str] = None + stochasticai_api_key: Optional[SecretStr] = None class Config: """Configuration for this pydantic object.""" @@ -61,8 +61,8 @@ class StochasticAI(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - stochasticai_api_key = get_from_dict_or_env( - values, "stochasticai_api_key", "STOCHASTICAI_API_KEY" + stochasticai_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "stochasticai_api_key", "STOCHASTICAI_API_KEY") ) values["stochasticai_api_key"] = stochasticai_api_key return values @@ -107,7 +107,7 @@ class StochasticAI(LLM): url=self.api_url, json={"prompt": prompt, "params": params}, headers={ - "apiKey": f"{self.stochasticai_api_key}", + "apiKey": f"{self.stochasticai_api_key.get_secret_value()}", "Accept": "application/json", "Content-Type": "application/json", }, @@ -119,7 +119,7 @@ class StochasticAI(LLM): response_get = requests.get( url=response_post_json["data"]["responseUrl"], headers={ - "apiKey": f"{self.stochasticai_api_key}", + "apiKey": f"{self.stochasticai_api_key.get_secret_value()}", "Accept": "application/json", "Content-Type": "application/json", }, diff --git a/libs/community/tests/unit_tests/llms/test_stochasticai.py b/libs/community/tests/unit_tests/llms/test_stochasticai.py new file mode 100644 index 00000000000..bf3e428c6b9 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_stochasticai.py @@ -0,0 +1,19 @@ +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.llms.stochasticai import StochasticAI + + +def test_api_key_is_string() -> None: + llm = StochasticAI(stochasticai_api_key="secret-api-key") + assert isinstance(llm.stochasticai_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = StochasticAI(stochasticai_api_key="secret-api-key") + print(llm.stochasticai_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********"