mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
feat: mask api key for cerebriumai llm (#14272)
- **Description:** Masking API key for CerebriumAI LLM to protect user secrets. - **Issue:** #12165 - **Dependencies:** None - **Tag maintainer:** @eyurtsev --------- Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
d4d64daa1e
commit
ad6dfb6220
@ -1,13 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional, cast
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -15,8 +15,9 @@ logger = logging.getLogger(__name__)
|
|||||||
class CerebriumAI(LLM):
|
class CerebriumAI(LLM):
|
||||||
"""CerebriumAI large language models.
|
"""CerebriumAI large language models.
|
||||||
|
|
||||||
To use, you should have the ``cerebrium`` python package installed, and the
|
To use, you should have the ``cerebrium`` python package installed.
|
||||||
environment variable ``CEREBRIUMAI_API_KEY`` set with your API key.
|
You should also have the environment variable ``CEREBRIUMAI_API_KEY``
|
||||||
|
set with your API key or pass it as a named argument in the constructor.
|
||||||
|
|
||||||
Any parameters that are valid to be passed to the call can be passed
|
Any parameters that are valid to be passed to the call can be passed
|
||||||
in, even if not explicitly saved on this class.
|
in, even if not explicitly saved on this class.
|
||||||
@ -25,7 +26,7 @@ class CerebriumAI(LLM):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.llms import CerebriumAI
|
from langchain.llms import CerebriumAI
|
||||||
cerebrium = CerebriumAI(endpoint_url="")
|
cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ class CerebriumAI(LLM):
|
|||||||
"""Holds any model parameters valid for `create` call not
|
"""Holds any model parameters valid for `create` call not
|
||||||
explicitly specified."""
|
explicitly specified."""
|
||||||
|
|
||||||
cerebriumai_api_key: Optional[str] = None
|
cerebriumai_api_key: Optional[SecretStr] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic config."""
|
"""Configuration for this pydantic config."""
|
||||||
@ -64,8 +65,8 @@ class CerebriumAI(LLM):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
cerebriumai_api_key = get_from_dict_or_env(
|
cerebriumai_api_key = convert_to_secret_str(
|
||||||
values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY"
|
get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
|
||||||
)
|
)
|
||||||
values["cerebriumai_api_key"] = cerebriumai_api_key
|
values["cerebriumai_api_key"] = cerebriumai_api_key
|
||||||
return values
|
return values
|
||||||
@ -91,7 +92,9 @@ class CerebriumAI(LLM):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
headers: Dict = {
|
headers: Dict = {
|
||||||
"Authorization": self.cerebriumai_api_key,
|
"Authorization": cast(
|
||||||
|
SecretStr, self.cerebriumai_api_key
|
||||||
|
).get_secret_value(),
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
params = self.model_kwargs or {}
|
params = self.model_kwargs or {}
|
||||||
|
33
libs/langchain/tests/unit_tests/llms/test_cerebriumai.py
Normal file
33
libs/langchain/tests/unit_tests/llms/test_cerebriumai.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""Test CerebriumAI llm"""
|
||||||
|
|
||||||
|
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
|
from pytest import CaptureFixture, MonkeyPatch
|
||||||
|
|
||||||
|
from langchain.llms.cerebriumai import CerebriumAI
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_is_secret_string() -> None:
|
||||||
|
llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key")
|
||||||
|
assert isinstance(llm.cerebriumai_api_key, SecretStr)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
|
||||||
|
llm = CerebriumAI(cerebriumai_api_key="secret-api-key")
|
||||||
|
print(llm.cerebriumai_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == "**********"
|
||||||
|
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_masked_when_passed_from_env(
|
||||||
|
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key")
|
||||||
|
llm = CerebriumAI()
|
||||||
|
print(llm.cerebriumai_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == "**********"
|
||||||
|
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
|
Loading…
Reference in New Issue
Block a user