langchain[patch]: Mask API key for Arcee LLM (#14193)

- **Description:** Mask API key for Arcee LLM and its associated unit
tests
  - **Issue:** https://github.com/langchain-ai/langchain/issues/12165
  - **Dependencies:** N/A
  - **Tag maintainer:** @eyurtsev
  - **Twitter handle:** `eekaiboon`

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Kaiboon Ee 2023-12-03 15:11:43 -08:00 committed by GitHub
parent 092f302c0f
commit e961c57fd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 22 deletions

View File

@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class Arcee(LLM):
@ -30,7 +30,7 @@ class Arcee(LLM):
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee _client."""
arcee_api_key: str = ""
arcee_api_key: Optional[SecretStr] = None
"""Arcee API Key"""
model: str
@ -66,26 +66,25 @@ class Arcee(LLM):
"""Initializes private fields."""
super().__init__(**data)
self._client = None
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_key=cast(SecretStr, self.arcee_api_key),
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
self._client.validate_model_training_status()
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
values["arcee_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
)
values["arcee_api_url"] = get_from_dict_or_env(

View File

@ -1,12 +1,12 @@
from typing import Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class ArceeRetriever(BaseRetriever):
@ -31,7 +31,7 @@ class ArceeRetriever(BaseRetriever):
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee client."""
arcee_api_key: str = ""
arcee_api_key: SecretStr
"""Arcee API Key"""
model: str
@ -75,10 +75,12 @@ class ArceeRetriever(BaseRetriever):
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
values["arcee_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
)
values["arcee_api_url"] = get_from_dict_or_env(

View File

@ -6,7 +6,7 @@ from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import requests
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.retrievers import Document
@ -100,7 +100,7 @@ class ArceeWrapper:
def __init__(
self,
arcee_api_key: str,
arcee_api_key: SecretStr,
arcee_api_url: str,
arcee_api_version: str,
model_kwargs: Optional[Dict[str, Any]],
@ -167,7 +167,7 @@ class ArceeWrapper:
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
headers = headers or {}
internal_headers = {
"X-Token": self.arcee_api_key,
"X-Token": self.arcee_api_key.get_secret_value(),
"Content-Type": "application/json",
}
headers.update(internal_headers)

View File

@ -0,0 +1,34 @@
"""Test Arcee llm"""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.arcee import Arcee
def test_api_key_is_secret_string() -> None:
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
assert isinstance(llm.arcee_api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
llm = Arcee(model="DALM-PubMed")
print(llm.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
print(llm.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"