mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
hide api key: arcee (#14304)
Hide API key for Arcee --------- Co-authored-by: raphael <raph.nunes95@gmail.com>
This commit is contained in:
parent
b2e756c0a8
commit
80637727ea
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, cast
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ class Arcee(LLM):
|
|||||||
_client: Optional[ArceeWrapper] = None #: :meta private:
|
_client: Optional[ArceeWrapper] = None #: :meta private:
|
||||||
"""Arcee _client."""
|
"""Arcee _client."""
|
||||||
|
|
||||||
arcee_api_key: Optional[SecretStr] = None
|
arcee_api_key: Union[SecretStr, str, None] = None
|
||||||
"""Arcee API Key"""
|
"""Arcee API Key"""
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
@ -66,15 +66,16 @@ class Arcee(LLM):
|
|||||||
"""Initializes private fields."""
|
"""Initializes private fields."""
|
||||||
|
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
api_key = cast(SecretStr, self.arcee_api_key)
|
||||||
self._client = ArceeWrapper(
|
self._client = ArceeWrapper(
|
||||||
arcee_api_key=cast(SecretStr, self.arcee_api_key),
|
arcee_api_key=api_key,
|
||||||
arcee_api_url=self.arcee_api_url,
|
arcee_api_url=self.arcee_api_url,
|
||||||
arcee_api_version=self.arcee_api_version,
|
arcee_api_version=self.arcee_api_version,
|
||||||
model_kwargs=self.model_kwargs,
|
model_kwargs=self.model_kwargs,
|
||||||
model_name=self.model,
|
model_name=self.model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False)
|
||||||
def validate_environments(cls, values: Dict) -> Dict:
|
def validate_environments(cls, values: Dict) -> Dict:
|
||||||
"""Validate Arcee environment variables."""
|
"""Validate Arcee environment variables."""
|
||||||
|
|
||||||
@ -106,7 +107,7 @@ class Arcee(LLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# validate model kwargs
|
# validate model kwargs
|
||||||
if values["model_kwargs"]:
|
if values.get("model_kwargs"):
|
||||||
kw = values["model_kwargs"]
|
kw = values["model_kwargs"]
|
||||||
|
|
||||||
# validate size
|
# validate size
|
||||||
@ -120,7 +121,6 @@ class Arcee(LLM):
|
|||||||
raise ValueError("`filters` must be a list")
|
raise ValueError("`filters` must be a list")
|
||||||
for f in kw.get("filters"):
|
for f in kw.get("filters"):
|
||||||
DALMFilter(**f)
|
DALMFilter(**f)
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
|
@ -61,7 +61,7 @@ class ArceeRetriever(BaseRetriever):
|
|||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
||||||
self._client = ArceeWrapper(
|
self._client = ArceeWrapper(
|
||||||
arcee_api_key=self.arcee_api_key,
|
arcee_api_key=self.arcee_api_key.get_secret_value(),
|
||||||
arcee_api_url=self.arcee_api_url,
|
arcee_api_url=self.arcee_api_url,
|
||||||
arcee_api_version=self.arcee_api_version,
|
arcee_api_version=self.arcee_api_version,
|
||||||
model_kwargs=self.model_kwargs,
|
model_kwargs=self.model_kwargs,
|
||||||
|
@ -96,11 +96,14 @@ class ArceeDocumentAdapter:
|
|||||||
|
|
||||||
|
|
||||||
class ArceeWrapper:
|
class ArceeWrapper:
|
||||||
"""Wrapper for Arcee API."""
|
"""Wrapper for Arcee API.
|
||||||
|
|
||||||
|
For more details, see: https://www.arcee.ai/
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
arcee_api_key: SecretStr,
|
arcee_api_key: Union[str, SecretStr],
|
||||||
arcee_api_url: str,
|
arcee_api_url: str,
|
||||||
arcee_api_version: str,
|
arcee_api_version: str,
|
||||||
model_kwargs: Optional[Dict[str, Any]],
|
model_kwargs: Optional[Dict[str, Any]],
|
||||||
@ -114,9 +117,12 @@ class ArceeWrapper:
|
|||||||
arcee_api_version: Version of Arcee API.
|
arcee_api_version: Version of Arcee API.
|
||||||
model_kwargs: Keyword arguments for Arcee API.
|
model_kwargs: Keyword arguments for Arcee API.
|
||||||
model_name: Name of an Arcee model.
|
model_name: Name of an Arcee model.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.arcee_api_key = arcee_api_key
|
if isinstance(arcee_api_key, str):
|
||||||
|
arcee_api_key_ = SecretStr(arcee_api_key)
|
||||||
|
else:
|
||||||
|
arcee_api_key_ = arcee_api_key
|
||||||
|
self.arcee_api_key: SecretStr = arcee_api_key_
|
||||||
self.model_kwargs = model_kwargs
|
self.model_kwargs = model_kwargs
|
||||||
self.arcee_api_url = arcee_api_url
|
self.arcee_api_url = arcee_api_url
|
||||||
self.arcee_api_version = arcee_api_version
|
self.arcee_api_version = arcee_api_version
|
||||||
@ -166,8 +172,13 @@ class ArceeWrapper:
|
|||||||
|
|
||||||
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
|
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
|
||||||
headers = headers or {}
|
headers = headers or {}
|
||||||
|
if not isinstance(self.arcee_api_key, SecretStr):
|
||||||
|
raise TypeError(
|
||||||
|
f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
|
||||||
|
)
|
||||||
|
api_key = self.arcee_api_key.get_secret_value()
|
||||||
internal_headers = {
|
internal_headers = {
|
||||||
"X-Token": self.arcee_api_key.get_secret_value(),
|
"X-Token": api_key,
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
headers.update(internal_headers)
|
headers.update(internal_headers)
|
||||||
|
@ -1,34 +1,70 @@
|
|||||||
"""Test Arcee llm"""
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
from pytest import CaptureFixture, MonkeyPatch
|
from pytest import CaptureFixture, MonkeyPatch
|
||||||
|
|
||||||
from langchain.llms.arcee import Arcee
|
from langchain.llms.arcee import Arcee
|
||||||
|
|
||||||
|
|
||||||
def test_api_key_is_secret_string() -> None:
|
@patch("langchain.utilities.arcee.requests.get")
|
||||||
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
|
def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None:
|
||||||
assert isinstance(llm.arcee_api_key, SecretStr)
|
mock_response = mock_get.return_value
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
def test_api_key_masked_when_passed_from_env(
|
"model_id": "",
|
||||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
"status": "training_complete",
|
||||||
) -> None:
|
}
|
||||||
"""Test initialization with an API key provided via an env variable"""
|
|
||||||
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
|
arcee_without_env_var = Arcee(
|
||||||
|
model="DALM-PubMed",
|
||||||
llm = Arcee(model="DALM-PubMed")
|
arcee_api_key="secret_api_key",
|
||||||
|
arcee_api_url="https://localhost",
|
||||||
print(llm.arcee_api_key, end="")
|
arcee_api_version="version",
|
||||||
captured = capsys.readouterr()
|
)
|
||||||
assert captured.out == "**********"
|
assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("langchain.utilities.arcee.requests.get")
|
||||||
def test_api_key_masked_when_passed_via_constructor(
|
def test_api_key_masked_when_passed_via_constructor(
|
||||||
capsys: CaptureFixture,
|
mock_get: MagicMock, capsys: CaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test initialization with an API key provided via the initializer"""
|
mock_response = mock_get.return_value
|
||||||
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"model_id": "",
|
||||||
|
"status": "training_complete",
|
||||||
|
}
|
||||||
|
|
||||||
print(llm.arcee_api_key, end="")
|
arcee_without_env_var = Arcee(
|
||||||
|
model="DALM-PubMed",
|
||||||
|
arcee_api_key="secret_api_key",
|
||||||
|
arcee_api_url="https://localhost",
|
||||||
|
arcee_api_version="version",
|
||||||
|
)
|
||||||
|
print(arcee_without_env_var.arcee_api_key, end="")
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert captured.out == "**********"
|
|
||||||
|
assert "**********" == captured.out
|
||||||
|
|
||||||
|
|
||||||
|
@patch("langchain.utilities.arcee.requests.get")
|
||||||
|
def test_api_key_masked_when_passed_from_env(
|
||||||
|
mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
mock_response = mock_get.return_value
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"model_id": "",
|
||||||
|
"status": "training_complete",
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key")
|
||||||
|
arcee_with_env_var = Arcee(
|
||||||
|
model="DALM-PubMed",
|
||||||
|
arcee_api_url="https://localhost",
|
||||||
|
arcee_api_version="version",
|
||||||
|
)
|
||||||
|
print(arcee_with_env_var.arcee_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert "**********" == captured.out
|
||||||
|
Loading…
Reference in New Issue
Block a user