mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 13:18:12 +00:00
hide api key: arcee (#14304)
Hide API key for Arcee --------- Co-authored-by: raphael <raph.nunes95@gmail.com>
This commit is contained in:
parent
b2e756c0a8
commit
80637727ea
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
|
||||
@ -30,7 +30,7 @@ class Arcee(LLM):
|
||||
_client: Optional[ArceeWrapper] = None #: :meta private:
|
||||
"""Arcee _client."""
|
||||
|
||||
arcee_api_key: Optional[SecretStr] = None
|
||||
arcee_api_key: Union[SecretStr, str, None] = None
|
||||
"""Arcee API Key"""
|
||||
|
||||
model: str
|
||||
@ -66,15 +66,16 @@ class Arcee(LLM):
|
||||
"""Initializes private fields."""
|
||||
|
||||
super().__init__(**data)
|
||||
api_key = cast(SecretStr, self.arcee_api_key)
|
||||
self._client = ArceeWrapper(
|
||||
arcee_api_key=cast(SecretStr, self.arcee_api_key),
|
||||
arcee_api_key=api_key,
|
||||
arcee_api_url=self.arcee_api_url,
|
||||
arcee_api_version=self.arcee_api_version,
|
||||
model_kwargs=self.model_kwargs,
|
||||
model_name=self.model,
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False)
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
"""Validate Arcee environment variables."""
|
||||
|
||||
@ -106,7 +107,7 @@ class Arcee(LLM):
|
||||
)
|
||||
|
||||
# validate model kwargs
|
||||
if values["model_kwargs"]:
|
||||
if values.get("model_kwargs"):
|
||||
kw = values["model_kwargs"]
|
||||
|
||||
# validate size
|
||||
@ -120,7 +121,6 @@ class Arcee(LLM):
|
||||
raise ValueError("`filters` must be a list")
|
||||
for f in kw.get("filters"):
|
||||
DALMFilter(**f)
|
||||
|
||||
return values
|
||||
|
||||
def _call(
|
||||
|
@ -61,7 +61,7 @@ class ArceeRetriever(BaseRetriever):
|
||||
super().__init__(**data)
|
||||
|
||||
self._client = ArceeWrapper(
|
||||
arcee_api_key=self.arcee_api_key,
|
||||
arcee_api_key=self.arcee_api_key.get_secret_value(),
|
||||
arcee_api_url=self.arcee_api_url,
|
||||
arcee_api_version=self.arcee_api_version,
|
||||
model_kwargs=self.model_kwargs,
|
||||
|
@ -96,11 +96,14 @@ class ArceeDocumentAdapter:
|
||||
|
||||
|
||||
class ArceeWrapper:
|
||||
"""Wrapper for Arcee API."""
|
||||
"""Wrapper for Arcee API.
|
||||
|
||||
For more details, see: https://www.arcee.ai/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arcee_api_key: SecretStr,
|
||||
arcee_api_key: Union[str, SecretStr],
|
||||
arcee_api_url: str,
|
||||
arcee_api_version: str,
|
||||
model_kwargs: Optional[Dict[str, Any]],
|
||||
@ -114,9 +117,12 @@ class ArceeWrapper:
|
||||
arcee_api_version: Version of Arcee API.
|
||||
model_kwargs: Keyword arguments for Arcee API.
|
||||
model_name: Name of an Arcee model.
|
||||
|
||||
"""
|
||||
self.arcee_api_key = arcee_api_key
|
||||
if isinstance(arcee_api_key, str):
|
||||
arcee_api_key_ = SecretStr(arcee_api_key)
|
||||
else:
|
||||
arcee_api_key_ = arcee_api_key
|
||||
self.arcee_api_key: SecretStr = arcee_api_key_
|
||||
self.model_kwargs = model_kwargs
|
||||
self.arcee_api_url = arcee_api_url
|
||||
self.arcee_api_version = arcee_api_version
|
||||
@ -166,8 +172,13 @@ class ArceeWrapper:
|
||||
|
||||
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
|
||||
headers = headers or {}
|
||||
if not isinstance(self.arcee_api_key, SecretStr):
|
||||
raise TypeError(
|
||||
f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
|
||||
)
|
||||
api_key = self.arcee_api_key.get_secret_value()
|
||||
internal_headers = {
|
||||
"X-Token": self.arcee_api_key.get_secret_value(),
|
||||
"X-Token": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers.update(internal_headers)
|
||||
|
@ -1,34 +1,70 @@
|
||||
"""Test Arcee llm"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.llms.arcee import Arcee
|
||||
|
||||
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
|
||||
assert isinstance(llm.arcee_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
|
||||
|
||||
llm = Arcee(model="DALM-PubMed")
|
||||
|
||||
print(llm.arcee_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "**********"
|
||||
@patch("langchain.utilities.arcee.requests.get")
|
||||
def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None:
|
||||
mock_response = mock_get.return_value
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"model_id": "",
|
||||
"status": "training_complete",
|
||||
}
|
||||
|
||||
arcee_without_env_var = Arcee(
|
||||
model="DALM-PubMed",
|
||||
arcee_api_key="secret_api_key",
|
||||
arcee_api_url="https://localhost",
|
||||
arcee_api_version="version",
|
||||
)
|
||||
assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr)
|
||||
|
||||
|
||||
@patch("langchain.utilities.arcee.requests.get")
|
||||
def test_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
mock_get: MagicMock, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
|
||||
mock_response = mock_get.return_value
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"model_id": "",
|
||||
"status": "training_complete",
|
||||
}
|
||||
|
||||
print(llm.arcee_api_key, end="")
|
||||
arcee_without_env_var = Arcee(
|
||||
model="DALM-PubMed",
|
||||
arcee_api_key="secret_api_key",
|
||||
arcee_api_url="https://localhost",
|
||||
arcee_api_version="version",
|
||||
)
|
||||
print(arcee_without_env_var.arcee_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "**********"
|
||||
|
||||
assert "**********" == captured.out
|
||||
|
||||
|
||||
@patch("langchain.utilities.arcee.requests.get")
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch
|
||||
) -> None:
|
||||
mock_response = mock_get.return_value
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"model_id": "",
|
||||
"status": "training_complete",
|
||||
}
|
||||
|
||||
monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key")
|
||||
arcee_with_env_var = Arcee(
|
||||
model="DALM-PubMed",
|
||||
arcee_api_url="https://localhost",
|
||||
arcee_api_version="version",
|
||||
)
|
||||
print(arcee_with_env_var.arcee_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "**********" == captured.out
|
||||
|
Loading…
Reference in New Issue
Block a user