mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
feat: mask api key for cerebriumai llm (#14272)
- **Description:** Masking API key for CerebriumAI LLM to protect user secrets. - **Issue:** #12165 - **Dependencies:** None - **Tag maintainer:** @eyurtsev --------- Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
d4d64daa1e
commit
ad6dfb6220
@ -1,13 +1,13 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||
|
||||
import requests
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -15,8 +15,9 @@ logger = logging.getLogger(__name__)
|
||||
class CerebriumAI(LLM):
|
||||
"""CerebriumAI large language models.
|
||||
|
||||
To use, you should have the ``cerebrium`` python package installed, and the
|
||||
environment variable ``CEREBRIUMAI_API_KEY`` set with your API key.
|
||||
To use, you should have the ``cerebrium`` python package installed.
|
||||
You should also have the environment variable ``CEREBRIUMAI_API_KEY``
|
||||
set with your API key or pass it as a named argument in the constructor.
|
||||
|
||||
Any parameters that are valid to be passed to the call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
@ -25,7 +26,7 @@ class CerebriumAI(LLM):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import CerebriumAI
|
||||
cerebrium = CerebriumAI(endpoint_url="")
|
||||
cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")
|
||||
|
||||
"""
|
||||
|
||||
@ -36,7 +37,7 @@ class CerebriumAI(LLM):
|
||||
"""Holds any model parameters valid for `create` call not
|
||||
explicitly specified."""
|
||||
|
||||
cerebriumai_api_key: Optional[str] = None
|
||||
cerebriumai_api_key: Optional[SecretStr] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic config."""
|
||||
@ -64,8 +65,8 @@ class CerebriumAI(LLM):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
cerebriumai_api_key = get_from_dict_or_env(
|
||||
values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY"
|
||||
cerebriumai_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
|
||||
)
|
||||
values["cerebriumai_api_key"] = cerebriumai_api_key
|
||||
return values
|
||||
@ -91,7 +92,9 @@ class CerebriumAI(LLM):
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
headers: Dict = {
|
||||
"Authorization": self.cerebriumai_api_key,
|
||||
"Authorization": cast(
|
||||
SecretStr, self.cerebriumai_api_key
|
||||
).get_secret_value(),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
params = self.model_kwargs or {}
|
||||
|
33
libs/langchain/tests/unit_tests/llms/test_cerebriumai.py
Normal file
33
libs/langchain/tests/unit_tests/llms/test_cerebriumai.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""Test CerebriumAI llm"""
|
||||
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.llms.cerebriumai import CerebriumAI
|
||||
|
||||
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key")
|
||||
assert isinstance(llm.cerebriumai_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
|
||||
llm = CerebriumAI(cerebriumai_api_key="secret-api-key")
|
||||
print(llm.cerebriumai_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key")
|
||||
llm = CerebriumAI()
|
||||
print(llm.cerebriumai_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
|
Loading…
Reference in New Issue
Block a user