langchain[patch]: Mask API key for Arcee LLM (#14193)

- **Description:** Mask API key for Arcee LLM and its associated unit
tests
  - **Issue:** https://github.com/langchain-ai/langchain/issues/12165
  - **Dependencies:** N/A
  - **Tag maintainer:** @eyurtsev
  - **Twitter handle:** `eekaiboon`

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Kaiboon Ee 2023-12-03 15:11:43 -08:00 committed by GitHub
parent 092f302c0f
commit e961c57fd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 22 deletions

View File

@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, cast
from langchain_core.pydantic_v1 import Extra, root_validator from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class Arcee(LLM): class Arcee(LLM):
@ -30,7 +30,7 @@ class Arcee(LLM):
_client: Optional[ArceeWrapper] = None #: :meta private: _client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee _client.""" """Arcee _client."""
arcee_api_key: str = "" arcee_api_key: Optional[SecretStr] = None
"""Arcee API Key""" """Arcee API Key"""
model: str model: str
@ -66,26 +66,25 @@ class Arcee(LLM):
"""Initializes private fields.""" """Initializes private fields."""
super().__init__(**data) super().__init__(**data)
self._client = None
self._client = ArceeWrapper( self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key, arcee_api_key=cast(SecretStr, self.arcee_api_key),
arcee_api_url=self.arcee_api_url, arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version, arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs, model_kwargs=self.model_kwargs,
model_name=self.model, model_name=self.model,
) )
self._client.validate_model_training_status()
@root_validator() @root_validator()
def validate_environments(cls, values: Dict) -> Dict: def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables.""" """Validate Arcee environment variables."""
# validate env vars # validate env vars
values["arcee_api_key"] = get_from_dict_or_env( values["arcee_api_key"] = convert_to_secret_str(
values, get_from_dict_or_env(
"arcee_api_key", values,
"ARCEE_API_KEY", "arcee_api_key",
"ARCEE_API_KEY",
)
) )
values["arcee_api_url"] = get_from_dict_or_env( values["arcee_api_url"] = get_from_dict_or_env(

View File

@ -1,12 +1,12 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, root_validator from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class ArceeRetriever(BaseRetriever): class ArceeRetriever(BaseRetriever):
@ -31,7 +31,7 @@ class ArceeRetriever(BaseRetriever):
_client: Optional[ArceeWrapper] = None #: :meta private: _client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee client.""" """Arcee client."""
arcee_api_key: str = "" arcee_api_key: SecretStr
"""Arcee API Key""" """Arcee API Key"""
model: str model: str
@ -75,10 +75,12 @@ class ArceeRetriever(BaseRetriever):
"""Validate Arcee environment variables.""" """Validate Arcee environment variables."""
# validate env vars # validate env vars
values["arcee_api_key"] = get_from_dict_or_env( values["arcee_api_key"] = convert_to_secret_str(
values, get_from_dict_or_env(
"arcee_api_key", values,
"ARCEE_API_KEY", "arcee_api_key",
"ARCEE_API_KEY",
)
) )
values["arcee_api_url"] = get_from_dict_or_env( values["arcee_api_url"] = get_from_dict_or_env(

View File

@ -6,7 +6,7 @@ from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import requests import requests
from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.retrievers import Document from langchain_core.retrievers import Document
@ -100,7 +100,7 @@ class ArceeWrapper:
def __init__( def __init__(
self, self,
arcee_api_key: str, arcee_api_key: SecretStr,
arcee_api_url: str, arcee_api_url: str,
arcee_api_version: str, arcee_api_version: str,
model_kwargs: Optional[Dict[str, Any]], model_kwargs: Optional[Dict[str, Any]],
@ -167,7 +167,7 @@ class ArceeWrapper:
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
headers = headers or {} headers = headers or {}
internal_headers = { internal_headers = {
"X-Token": self.arcee_api_key, "X-Token": self.arcee_api_key.get_secret_value(),
"Content-Type": "application/json", "Content-Type": "application/json",
} }
headers.update(internal_headers) headers.update(internal_headers)

View File

@ -0,0 +1,34 @@
"""Test Arcee llm"""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.arcee import Arcee
def test_api_key_is_secret_string() -> None:
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
assert isinstance(llm.arcee_api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
llm = Arcee(model="DALM-PubMed")
print(llm.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
print(llm.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"