diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 04b278f6763..7e83219cdd6 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Union, cast from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: Optional[SecretStr] = None + arcee_api_key: Union[SecretStr, str, None] = None """Arcee API Key""" model: str @@ -66,15 +66,16 @@ class Arcee(LLM): """Initializes private fields.""" super().__init__(**data) + api_key = cast(SecretStr, self.arcee_api_key) self._client = ArceeWrapper( - arcee_api_key=cast(SecretStr, self.arcee_api_key), + arcee_api_key=api_key, arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, model_name=self.model, ) - @root_validator() + @root_validator(pre=False) def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" @@ -106,7 +107,7 @@ class Arcee(LLM): ) # validate model kwargs - if values["model_kwargs"]: + if values.get("model_kwargs"): kw = values["model_kwargs"] # validate size @@ -120,7 +121,6 @@ class Arcee(LLM): raise ValueError("`filters` must be a list") for f in kw.get("filters"): DALMFilter(**f) - return values def _call( diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py index 7d3e7b822f5..e360f62a03f 100644 --- a/libs/langchain/langchain/retrievers/arcee.py +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -61,7 +61,7 @@ class ArceeRetriever(BaseRetriever): super().__init__(**data) self._client = ArceeWrapper( - arcee_api_key=self.arcee_api_key, + arcee_api_key=self.arcee_api_key.get_secret_value(), arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 743930b93e8..72170348583 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -96,11 +96,14 @@ class ArceeDocumentAdapter: class ArceeWrapper: - """Wrapper for Arcee API.""" + """Wrapper for Arcee API. + + For more details, see: https://www.arcee.ai/ + """ def __init__( self, - arcee_api_key: SecretStr, + arcee_api_key: Union[str, SecretStr], arcee_api_url: str, arcee_api_version: str, model_kwargs: Optional[Dict[str, Any]], @@ -114,9 +117,12 @@ class ArceeWrapper: arcee_api_version: Version of Arcee API. model_kwargs: Keyword arguments for Arcee API. model_name: Name of an Arcee model. - """ - self.arcee_api_key = arcee_api_key + if isinstance(arcee_api_key, str): + arcee_api_key_ = SecretStr(arcee_api_key) + else: + arcee_api_key_ = arcee_api_key + self.arcee_api_key: SecretStr = arcee_api_key_ self.model_kwargs = model_kwargs self.arcee_api_url = arcee_api_url self.arcee_api_version = arcee_api_version @@ -166,8 +172,13 @@ class ArceeWrapper: def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} + if not isinstance(self.arcee_api_key, SecretStr): + raise TypeError( + f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}" + ) + api_key = self.arcee_api_key.get_secret_value() internal_headers = { - "X-Token": self.arcee_api_key.get_secret_value(), + "X-Token": api_key, "Content-Type": "application/json", } headers.update(internal_headers) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index a795c1909ca..40daec3682f 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -1,34 +1,70 @@ -"""Test Arcee llm""" +from unittest.mock import MagicMock, patch + from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain.llms.arcee import Arcee -def test_api_key_is_secret_string() -> None: - llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") - assert isinstance(llm.arcee_api_key, SecretStr) - - -def test_api_key_masked_when_passed_from_env( - monkeypatch: MonkeyPatch, capsys: CaptureFixture -) -> None: - """Test initialization with an API key provided via an env variable""" - monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key") - - llm = Arcee(model="DALM-PubMed") - - print(llm.arcee_api_key, end="") - captured = capsys.readouterr() - assert captured.out == "**********" +@patch("langchain.utilities.arcee.requests.get") +def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } + + arcee_without_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr) +@patch("langchain.utilities.arcee.requests.get") def test_api_key_masked_when_passed_via_constructor( - capsys: CaptureFixture, + mock_get: MagicMock, capsys: CaptureFixture ) -> None: - """Test initialization with an API key provided via the initializer""" - llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } - print(llm.arcee_api_key, end="") + arcee_without_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + print(arcee_without_env_var.arcee_api_key, end="") captured = capsys.readouterr() - assert captured.out == "**********" + + assert "**********" == captured.out + + +@patch("langchain.utilities.arcee.requests.get") +def test_api_key_masked_when_passed_from_env( + mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch +) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } + + monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key") + arcee_with_env_var = Arcee( + model="DALM-PubMed", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + print(arcee_with_env_var.arcee_api_key, end="") + captured = capsys.readouterr() + + assert "**********" == captured.out