mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
langchain[patch]: Mask API key for Arcee LLM (#14193)
- **Description:** Mask API key for Arcee LLM and its associated unit tests - **Issue:** https://github.com/langchain-ai/langchain/issues/12165 - **Dependencies:** N/A - **Tag maintainer:** @eyurtsev - **Twitter handle:** `eekaiboon` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
092f302c0f
commit
e961c57fd2
@ -1,11 +1,11 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
|
||||
class Arcee(LLM):
|
||||
@ -30,7 +30,7 @@ class Arcee(LLM):
|
||||
_client: Optional[ArceeWrapper] = None #: :meta private:
|
||||
"""Arcee _client."""
|
||||
|
||||
arcee_api_key: str = ""
|
||||
arcee_api_key: Optional[SecretStr] = None
|
||||
"""Arcee API Key"""
|
||||
|
||||
model: str
|
||||
@ -66,26 +66,25 @@ class Arcee(LLM):
|
||||
"""Initializes private fields."""
|
||||
|
||||
super().__init__(**data)
|
||||
self._client = None
|
||||
self._client = ArceeWrapper(
|
||||
arcee_api_key=self.arcee_api_key,
|
||||
arcee_api_key=cast(SecretStr, self.arcee_api_key),
|
||||
arcee_api_url=self.arcee_api_url,
|
||||
arcee_api_version=self.arcee_api_version,
|
||||
model_kwargs=self.model_kwargs,
|
||||
model_name=self.model,
|
||||
)
|
||||
|
||||
self._client.validate_model_training_status()
|
||||
|
||||
@root_validator()
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
"""Validate Arcee environment variables."""
|
||||
|
||||
# validate env vars
|
||||
values["arcee_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_key",
|
||||
"ARCEE_API_KEY",
|
||||
values["arcee_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_key",
|
||||
"ARCEE_API_KEY",
|
||||
)
|
||||
)
|
||||
|
||||
values["arcee_api_url"] = get_from_dict_or_env(
|
||||
|
@ -1,12 +1,12 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
|
||||
class ArceeRetriever(BaseRetriever):
|
||||
@ -31,7 +31,7 @@ class ArceeRetriever(BaseRetriever):
|
||||
_client: Optional[ArceeWrapper] = None #: :meta private:
|
||||
"""Arcee client."""
|
||||
|
||||
arcee_api_key: str = ""
|
||||
arcee_api_key: SecretStr
|
||||
"""Arcee API Key"""
|
||||
|
||||
model: str
|
||||
@ -75,10 +75,12 @@ class ArceeRetriever(BaseRetriever):
|
||||
"""Validate Arcee environment variables."""
|
||||
|
||||
# validate env vars
|
||||
values["arcee_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_key",
|
||||
"ARCEE_API_KEY",
|
||||
values["arcee_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_key",
|
||||
"ARCEE_API_KEY",
|
||||
)
|
||||
)
|
||||
|
||||
values["arcee_api_url"] = get_from_dict_or_env(
|
||||
|
@ -6,7 +6,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.retrievers import Document
|
||||
|
||||
|
||||
@ -100,7 +100,7 @@ class ArceeWrapper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arcee_api_key: str,
|
||||
arcee_api_key: SecretStr,
|
||||
arcee_api_url: str,
|
||||
arcee_api_version: str,
|
||||
model_kwargs: Optional[Dict[str, Any]],
|
||||
@ -167,7 +167,7 @@ class ArceeWrapper:
|
||||
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
|
||||
headers = headers or {}
|
||||
internal_headers = {
|
||||
"X-Token": self.arcee_api_key,
|
||||
"X-Token": self.arcee_api_key.get_secret_value(),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers.update(internal_headers)
|
||||
|
34
libs/langchain/tests/integration_tests/llms/test_arcee.py
Normal file
34
libs/langchain/tests/integration_tests/llms/test_arcee.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""Test Arcee llm"""
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.llms.arcee import Arcee
|
||||
|
||||
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
|
||||
assert isinstance(llm.arcee_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
|
||||
|
||||
llm = Arcee(model="DALM-PubMed")
|
||||
|
||||
print(llm.arcee_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
|
||||
|
||||
print(llm.arcee_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "**********"
|
Loading…
Reference in New Issue
Block a user