Merge pull request #13945

* feat: mask api key for nlpcloud
This commit is contained in:
卢靖轩 2023-11-29 10:16:36 +08:00 committed by GitHub
parent 85bb3a418c
commit aff1dba252
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 6 deletions

View File

@ -1,10 +1,10 @@
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class NLPCloud(LLM):
@ -50,7 +50,7 @@ class NLPCloud(LLM):
num_return_sequences: int = 1
"""How many completions to generate for each prompt."""
nlpcloud_api_key: Optional[str] = None
nlpcloud_api_key: Optional[SecretStr] = None
class Config:
"""Configuration for this pydantic object."""
@ -60,15 +60,15 @@ class NLPCloud(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
nlpcloud_api_key = get_from_dict_or_env(
values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
values["nlpcloud_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "nlpcloud_api_key", "NLPCLOUD_API_KEY")
)
try:
import nlpcloud
values["client"] = nlpcloud.Client(
values["model_name"],
nlpcloud_api_key,
values["nlpcloud_api_key"].get_secret_value(),
gpu=values["gpu"],
lang=values["lang"],
)

View File

@ -1,9 +1,13 @@
"""Test NLPCloud API wrapper."""
from pathlib import Path
from typing import cast
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.loading import load_llm
from langchain.llms.nlpcloud import NLPCloud
from langchain.pydantic_v1 import SecretStr
from tests.integration_tests.llms.utils import assert_llm_equality
@ -20,3 +24,20 @@ def test_saving_loading_llm(tmp_path: Path) -> None:
llm.save(file_path=tmp_path / "nlpcloud.yaml")
loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
assert_llm_equality(llm, loaded_llm)
def test_nlpcloud_api_key(monkeypatch: MonkeyPatch, capsys: CaptureFixture) -> None:
"""Test that nlpcloud api key is a secret key."""
# test initialization from init
assert isinstance(NLPCloud(nlpcloud_api_key="1").nlpcloud_api_key, SecretStr)
monkeypatch.setenv("NLPCLOUD_API_KEY", "secret-api-key")
llm = NLPCloud()
assert isinstance(llm.nlpcloud_api_key, SecretStr)
assert cast(SecretStr, llm.nlpcloud_api_key).get_secret_value() == "secret-api-key"
print(llm.nlpcloud_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"