diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index 2aaafd9128f..f04ce49398d 100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Field +from langchain_core.pydantic_v1 import Field, SecretStr class Predibase(LLM): @@ -13,7 +13,7 @@ class Predibase(LLM): """ model: str - predibase_api_key: str + predibase_api_key: SecretStr model_kwargs: Dict[str, Any] = Field(default_factory=dict) @property @@ -30,7 +30,7 @@ class Predibase(LLM): try: from predibase import PredibaseClient - pc = PredibaseClient(token=self.predibase_api_key) + pc = PredibaseClient(token=self.predibase_api_key.get_secret_value()) except ImportError as e: raise ImportError( "Could not import Predibase Python package. " diff --git a/libs/community/tests/integration_tests/llms/test_predibase.py b/libs/community/tests/integration_tests/llms/test_predibase.py new file mode 100644 index 00000000000..2f8b9233865 --- /dev/null +++ b/libs/community/tests/integration_tests/llms/test_predibase.py @@ -0,0 +1,19 @@ +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.llms.predibase import Predibase + + +def test_api_key_is_string() -> None: + llm = Predibase(predibase_api_key="secret-api-key") + assert isinstance(llm.predibase_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = Predibase(predibase_api_key="secret-api-key") + print(llm.predibase_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********"