diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 72028097c2a..04b278f6763 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utilities.arcee import ArceeWrapper, DALMFilter -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class Arcee(LLM): @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: str = "" + arcee_api_key: Optional[SecretStr] = None """Arcee API Key""" model: str @@ -66,26 +66,25 @@ class Arcee(LLM): """Initializes private fields.""" super().__init__(**data) - self._client = None self._client = ArceeWrapper( - arcee_api_key=self.arcee_api_key, + arcee_api_key=cast(SecretStr, self.arcee_api_key), arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, model_name=self.model, ) - self._client.validate_model_training_status() - @root_validator() def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" # validate env vars - values["arcee_api_key"] = get_from_dict_or_env( - values, - "arcee_api_key", - "ARCEE_API_KEY", + values["arcee_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) ) values["arcee_api_url"] = get_from_dict_or_env( diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py index faf014eaf31..7d3e7b822f5 100644 --- a/libs/langchain/langchain/retrievers/arcee.py +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -1,12 +1,12 @@ from typing import Any, Dict, List, Optional from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator from langchain_core.retrievers import BaseRetriever from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.utilities.arcee import ArceeWrapper, DALMFilter -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class ArceeRetriever(BaseRetriever): @@ -31,7 +31,7 @@ class ArceeRetriever(BaseRetriever): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee client.""" - arcee_api_key: str = "" + arcee_api_key: SecretStr """Arcee API Key""" model: str @@ -75,10 +75,12 @@ class ArceeRetriever(BaseRetriever): """Validate Arcee environment variables.""" # validate env vars - values["arcee_api_key"] = get_from_dict_or_env( - values, - "arcee_api_key", - "ARCEE_API_KEY", + values["arcee_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) ) values["arcee_api_url"] = get_from_dict_or_env( diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 318af14eb57..743930b93e8 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Any, Dict, List, Literal, Mapping, Optional, Union import requests -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator from langchain_core.retrievers import Document @@ -100,7 +100,7 @@ class ArceeWrapper: def __init__( self, - arcee_api_key: str, + arcee_api_key: SecretStr, arcee_api_url: str, arcee_api_version: str, model_kwargs: Optional[Dict[str, Any]], @@ -167,7 +167,7 @@ class ArceeWrapper: def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} internal_headers = { - "X-Token": self.arcee_api_key, + "X-Token": self.arcee_api_key.get_secret_value(), "Content-Type": "application/json", } headers.update(internal_headers) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py new file mode 100644 index 00000000000..a795c1909ca --- /dev/null +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -0,0 +1,34 @@ +"""Test Arcee llm""" +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch + +from langchain.llms.arcee import Arcee + + +def test_api_key_is_secret_string() -> None: + llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") + assert isinstance(llm.arcee_api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key") + + llm = Arcee(model="DALM-PubMed") + + print(llm.arcee_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") + + print(llm.arcee_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********"