feat: mask api key for cerebriumai llm (#14272)

- **Description:** Masking API key for CerebriumAI LLM to protect user
secrets.
 - **Issue:** #12165 
 - **Dependencies:** None
 - **Tag maintainer:** @eyurtsev

---------

Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Yuchen Liang 2023-12-06 12:06:00 -05:00 committed by GitHub
parent d4d64daa1e
commit ad6dfb6220
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 10 deletions

View File

@ -1,13 +1,13 @@
import logging import logging
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Mapping, Optional, cast
import requests import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env from langchain.utils import convert_to_secret_str, get_from_dict_or_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,8 +15,9 @@ logger = logging.getLogger(__name__)
class CerebriumAI(LLM): class CerebriumAI(LLM):
"""CerebriumAI large language models. """CerebriumAI large language models.
To use, you should have the ``cerebrium`` python package installed, and the To use, you should have the ``cerebrium`` python package installed.
environment variable ``CEREBRIUMAI_API_KEY`` set with your API key. You should also have the environment variable ``CEREBRIUMAI_API_KEY``
set with your API key or pass it as a named argument in the constructor.
Any parameters that are valid to be passed to the call can be passed Any parameters that are valid to be passed to the call can be passed
in, even if not explicitly saved on this class. in, even if not explicitly saved on this class.
@ -25,7 +26,7 @@ class CerebriumAI(LLM):
.. code-block:: python .. code-block:: python
from langchain.llms import CerebriumAI from langchain.llms import CerebriumAI
cerebrium = CerebriumAI(endpoint_url="") cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")
""" """
@ -36,7 +37,7 @@ class CerebriumAI(LLM):
"""Holds any model parameters valid for `create` call not """Holds any model parameters valid for `create` call not
explicitly specified.""" explicitly specified."""
cerebriumai_api_key: Optional[str] = None cerebriumai_api_key: Optional[SecretStr] = None
class Config: class Config:
"""Configuration for this pydantic config.""" """Configuration for this pydantic config."""
@ -64,8 +65,8 @@ class CerebriumAI(LLM):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
cerebriumai_api_key = get_from_dict_or_env( cerebriumai_api_key = convert_to_secret_str(
values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY" get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
) )
values["cerebriumai_api_key"] = cerebriumai_api_key values["cerebriumai_api_key"] = cerebriumai_api_key
return values return values
@ -91,7 +92,9 @@ class CerebriumAI(LLM):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
headers: Dict = { headers: Dict = {
"Authorization": self.cerebriumai_api_key, "Authorization": cast(
SecretStr, self.cerebriumai_api_key
).get_secret_value(),
"Content-Type": "application/json", "Content-Type": "application/json",
} }
params = self.model_kwargs or {} params = self.model_kwargs or {}

View File

@ -0,0 +1,33 @@
"""Test CerebriumAI llm"""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.cerebriumai import CerebriumAI
def test_api_key_is_secret_string() -> None:
llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key")
assert isinstance(llm.cerebriumai_api_key, SecretStr)
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
llm = CerebriumAI(cerebriumai_api_key="secret-api-key")
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key")
llm = CerebriumAI()
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"