diff --git a/libs/community/langchain_community/llms/volcengine_maas.py b/libs/community/langchain_community/llms/volcengine_maas.py index 9f32005f464..444800e0496 100644 --- a/libs/community/langchain_community/llms/volcengine_maas.py +++ b/libs/community/langchain_community/llms/volcengine_maas.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Iterator, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env class VolcEngineMaasBase(BaseModel): @@ -14,9 +14,9 @@ class VolcEngineMaasBase(BaseModel): client: Any - volc_engine_maas_ak: Optional[str] = None + volc_engine_maas_ak: Optional[SecretStr] = None """access key for volc engine""" - volc_engine_maas_sk: Optional[str] = None + volc_engine_maas_sk: Optional[SecretStr] = None """secret key for volc engine""" endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com" @@ -54,8 +54,12 @@ class VolcEngineMaasBase(BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: - ak = get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY") - sk = get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY") + volc_engine_maas_ak = convert_to_secret_str( + get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY") + ) + volc_engine_maas_sk = convert_to_secret_str( + get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY") + ) endpoint = values["endpoint"] if values["endpoint"] is not None and values["endpoint"] != "": endpoint = values["endpoint"] @@ -68,10 +72,11 @@ class VolcEngineMaasBase(BaseModel): connection_timeout=values["connect_timeout"], socket_timeout=values["read_timeout"], ) - maas.set_ak(ak) - values["volc_engine_maas_ak"] = ak - values["volc_engine_maas_sk"] = sk - maas.set_sk(sk) + maas.set_ak(volc_engine_maas_ak.get_secret_value()) + maas.set_sk(volc_engine_maas_sk.get_secret_value()) + + values["volc_engine_maas_ak"] = volc_engine_maas_ak + values["volc_engine_maas_sk"] = volc_engine_maas_sk values["client"] = maas except ImportError: raise ImportError( diff --git a/libs/community/tests/integration_tests/llms/test_volcengine_maas.py b/libs/community/tests/integration_tests/llms/test_volcengine_maas.py index 52a3acfe975..7d3094e627c 100644 --- a/libs/community/tests/integration_tests/llms/test_volcengine_maas.py +++ b/libs/community/tests/integration_tests/llms/test_volcengine_maas.py @@ -3,8 +3,35 @@ from typing import Generator from langchain_core.outputs import LLMResult +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture -from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM +from langchain_community.llms.volcengine_maas import ( + VolcEngineMaasBase, + VolcEngineMaasLLM, +) + + +def test_api_key_is_string() -> None: + llm = VolcEngineMaasBase( + volc_engine_maas_ak="secret-volc-ak", + volc_engine_maas_sk="secret-volc-sk", + ) + assert isinstance(llm.volc_engine_maas_ak, SecretStr) + assert isinstance(llm.volc_engine_maas_sk, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = VolcEngineMaasBase( + volc_engine_maas_ak="secret-volc-ak", + volc_engine_maas_sk="secret-volc-sk", + ) + print(llm.volc_engine_maas_ak, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" def test_default_call() -> None: