hide api key: arcee (#14304)

Hide API key for Arcee

---------

Co-authored-by: raphael <raph.nunes95@gmail.com>
This commit is contained in:
Eugene Yurtsev 2023-12-05 14:49:55 -05:00 committed by GitHub
parent b2e756c0a8
commit 80637727ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 81 additions and 34 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, Union, cast
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
@ -30,7 +30,7 @@ class Arcee(LLM):
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee _client."""
arcee_api_key: Optional[SecretStr] = None
arcee_api_key: Union[SecretStr, str, None] = None
"""Arcee API Key"""
model: str
@ -66,15 +66,16 @@ class Arcee(LLM):
"""Initializes private fields."""
super().__init__(**data)
api_key = cast(SecretStr, self.arcee_api_key)
self._client = ArceeWrapper(
arcee_api_key=cast(SecretStr, self.arcee_api_key),
arcee_api_key=api_key,
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
@root_validator()
@root_validator(pre=False)
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
@ -106,7 +107,7 @@ class Arcee(LLM):
)
# validate model kwargs
if values["model_kwargs"]:
if values.get("model_kwargs"):
kw = values["model_kwargs"]
# validate size
@ -120,7 +121,6 @@ class Arcee(LLM):
raise ValueError("`filters` must be a list")
for f in kw.get("filters"):
DALMFilter(**f)
return values
def _call(

View File

@ -61,7 +61,7 @@ class ArceeRetriever(BaseRetriever):
super().__init__(**data)
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_key=self.arcee_api_key.get_secret_value(),
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,

View File

@ -96,11 +96,14 @@ class ArceeDocumentAdapter:
class ArceeWrapper:
"""Wrapper for Arcee API."""
"""Wrapper for Arcee API.
For more details, see: https://www.arcee.ai/
"""
def __init__(
self,
arcee_api_key: SecretStr,
arcee_api_key: Union[str, SecretStr],
arcee_api_url: str,
arcee_api_version: str,
model_kwargs: Optional[Dict[str, Any]],
@ -114,9 +117,12 @@ class ArceeWrapper:
arcee_api_version: Version of Arcee API.
model_kwargs: Keyword arguments for Arcee API.
model_name: Name of an Arcee model.
"""
self.arcee_api_key = arcee_api_key
if isinstance(arcee_api_key, str):
arcee_api_key_ = SecretStr(arcee_api_key)
else:
arcee_api_key_ = arcee_api_key
self.arcee_api_key: SecretStr = arcee_api_key_
self.model_kwargs = model_kwargs
self.arcee_api_url = arcee_api_url
self.arcee_api_version = arcee_api_version
@ -166,8 +172,13 @@ class ArceeWrapper:
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
headers = headers or {}
if not isinstance(self.arcee_api_key, SecretStr):
raise TypeError(
f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
)
api_key = self.arcee_api_key.get_secret_value()
internal_headers = {
"X-Token": self.arcee_api_key.get_secret_value(),
"X-Token": api_key,
"Content-Type": "application/json",
}
headers.update(internal_headers)

View File

@ -1,34 +1,70 @@
"""Test Arcee llm"""
from unittest.mock import MagicMock, patch
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.arcee import Arcee
def test_api_key_is_secret_string() -> None:
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
assert isinstance(llm.arcee_api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
llm = Arcee(model="DALM-PubMed")
print(llm.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
@patch("langchain.utilities.arcee.requests.get")
def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None:
mock_response = mock_get.return_value
mock_response.status_code = 200
mock_response.json.return_value = {
"model_id": "",
"status": "training_complete",
}
arcee_without_env_var = Arcee(
model="DALM-PubMed",
arcee_api_key="secret_api_key",
arcee_api_url="https://localhost",
arcee_api_version="version",
)
assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr)
@patch("langchain.utilities.arcee.requests.get")
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
mock_get: MagicMock, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
mock_response = mock_get.return_value
mock_response.status_code = 200
mock_response.json.return_value = {
"model_id": "",
"status": "training_complete",
}
print(llm.arcee_api_key, end="")
arcee_without_env_var = Arcee(
model="DALM-PubMed",
arcee_api_key="secret_api_key",
arcee_api_url="https://localhost",
arcee_api_version="version",
)
print(arcee_without_env_var.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
assert "**********" == captured.out
@patch("langchain.utilities.arcee.requests.get")
def test_api_key_masked_when_passed_from_env(
mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch
) -> None:
mock_response = mock_get.return_value
mock_response.status_code = 200
mock_response.json.return_value = {
"model_id": "",
"status": "training_complete",
}
monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key")
arcee_with_env_var = Arcee(
model="DALM-PubMed",
arcee_api_url="https://localhost",
arcee_api_version="version",
)
print(arcee_with_env_var.arcee_api_key, end="")
captured = capsys.readouterr()
assert "**********" == captured.out