mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
Mask API key for Minimax LLM (#14309)
- **Description:** Added masking for the API key for Minimax LLM + tests inspired by https://github.com/langchain-ai/langchain/pull/12418. - **Issue:** the issue # fixes https://github.com/langchain-ai/langchain/issues/12165 - **Dependencies:** this fix is dependent on Minimax instantiation fix which is introduced in https://github.com/langchain-ai/langchain/pull/13439, so merge this one after. - **Tag maintainer:** @eyurtsev --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
29e993a5f2
commit
d22c13ec48
@ -10,14 +10,14 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ class _MinimaxEndpointClient(BaseModel):
|
|||||||
|
|
||||||
host: str
|
host: str
|
||||||
group_id: str
|
group_id: str
|
||||||
api_key: str
|
api_key: SecretStr
|
||||||
api_url: str
|
api_url: str
|
||||||
|
|
||||||
@root_validator(pre=True, allow_reuse=True)
|
@root_validator(pre=True, allow_reuse=True)
|
||||||
@ -40,7 +40,7 @@ class _MinimaxEndpointClient(BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def post(self, request: Any) -> Any:
|
def post(self, request: Any) -> Any:
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
|
||||||
response = requests.post(self.api_url, headers=headers, json=request)
|
response = requests.post(self.api_url, headers=headers, json=request)
|
||||||
# TODO: error handling and automatic retries
|
# TODO: error handling and automatic retries
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
@ -56,7 +56,7 @@ class _MinimaxEndpointClient(BaseModel):
|
|||||||
class MinimaxCommon(BaseModel):
|
class MinimaxCommon(BaseModel):
|
||||||
"""Common parameters for Minimax large language models."""
|
"""Common parameters for Minimax large language models."""
|
||||||
|
|
||||||
_client: Any = None
|
_client: _MinimaxEndpointClient
|
||||||
model: str = "abab5.5-chat"
|
model: str = "abab5.5-chat"
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
max_tokens: int = 256
|
max_tokens: int = 256
|
||||||
@ -69,13 +69,13 @@ class MinimaxCommon(BaseModel):
|
|||||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||||
minimax_api_host: Optional[str] = None
|
minimax_api_host: Optional[str] = None
|
||||||
minimax_group_id: Optional[str] = None
|
minimax_group_id: Optional[str] = None
|
||||||
minimax_api_key: Optional[str] = None
|
minimax_api_key: Optional[SecretStr] = None
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
values["minimax_api_key"] = get_from_dict_or_env(
|
values["minimax_api_key"] = convert_to_secret_str(
|
||||||
values, "minimax_api_key", "MINIMAX_API_KEY"
|
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
|
||||||
)
|
)
|
||||||
values["minimax_group_id"] = get_from_dict_or_env(
|
values["minimax_group_id"] = get_from_dict_or_env(
|
||||||
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
||||||
@ -87,6 +87,11 @@ class MinimaxCommon(BaseModel):
|
|||||||
"MINIMAX_API_HOST",
|
"MINIMAX_API_HOST",
|
||||||
default="https://api.minimax.chat",
|
default="https://api.minimax.chat",
|
||||||
)
|
)
|
||||||
|
values["_client"] = _MinimaxEndpointClient(
|
||||||
|
host=values["minimax_api_host"],
|
||||||
|
api_key=values["minimax_api_key"],
|
||||||
|
group_id=values["minimax_group_id"],
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -110,14 +115,6 @@ class MinimaxCommon(BaseModel):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "minimax"
|
return "minimax"
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
|
||||||
super().__init__(**data)
|
|
||||||
self._client = _MinimaxEndpointClient(
|
|
||||||
host=self.minimax_api_host,
|
|
||||||
api_key=self.minimax_api_key,
|
|
||||||
group_id=self.minimax_group_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Minimax(MinimaxCommon, LLM):
|
class Minimax(MinimaxCommon, LLM):
|
||||||
"""Wrapper around Minimax large language models.
|
"""Wrapper around Minimax large language models.
|
||||||
|
42
libs/langchain/tests/unit_tests/llms/test_minimax.py
Normal file
42
libs/langchain/tests/unit_tests/llms/test_minimax.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""Test Minimax llm"""
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
|
from pytest import CaptureFixture, MonkeyPatch
|
||||||
|
|
||||||
|
from langchain.llms.minimax import Minimax
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_is_secret_string() -> None:
|
||||||
|
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
|
||||||
|
assert isinstance(llm.minimax_api_key, SecretStr)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_masked_when_passed_from_env(
|
||||||
|
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test initialization with an API key provided via an env variable"""
|
||||||
|
monkeypatch.setenv("MINIMAX_API_KEY", "secret-api-key")
|
||||||
|
monkeypatch.setenv("MINIMAX_GROUP_ID", "group_id")
|
||||||
|
llm = Minimax()
|
||||||
|
print(llm.minimax_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_masked_when_passed_via_constructor(
|
||||||
|
capsys: CaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test initialization with an API key provided via the initializer"""
|
||||||
|
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
|
||||||
|
print(llm.minimax_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == "**********"
|
||||||
|
|
||||||
|
|
||||||
|
def test_uses_actual_secret_value_from_secretstr() -> None:
|
||||||
|
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
||||||
|
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
|
||||||
|
assert cast(SecretStr, llm.minimax_api_key).get_secret_value() == "secret-api-key"
|
Loading…
Reference in New Issue
Block a user