mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
parent
85bb3a418c
commit
aff1dba252
@ -1,10 +1,10 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
|
||||
class NLPCloud(LLM):
|
||||
@ -50,7 +50,7 @@ class NLPCloud(LLM):
|
||||
num_return_sequences: int = 1
|
||||
"""How many completions to generate for each prompt."""
|
||||
|
||||
nlpcloud_api_key: Optional[str] = None
|
||||
nlpcloud_api_key: Optional[SecretStr] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -60,15 +60,15 @@ class NLPCloud(LLM):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
nlpcloud_api_key = get_from_dict_or_env(
|
||||
values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
|
||||
values["nlpcloud_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "nlpcloud_api_key", "NLPCLOUD_API_KEY")
|
||||
)
|
||||
try:
|
||||
import nlpcloud
|
||||
|
||||
values["client"] = nlpcloud.Client(
|
||||
values["model_name"],
|
||||
nlpcloud_api_key,
|
||||
values["nlpcloud_api_key"].get_secret_value(),
|
||||
gpu=values["gpu"],
|
||||
lang=values["lang"],
|
||||
)
|
||||
|
@ -1,9 +1,13 @@
|
||||
"""Test NLPCloud API wrapper."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.llms.loading import load_llm
|
||||
from langchain.llms.nlpcloud import NLPCloud
|
||||
from langchain.pydantic_v1 import SecretStr
|
||||
from tests.integration_tests.llms.utils import assert_llm_equality
|
||||
|
||||
|
||||
@ -20,3 +24,20 @@ def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||
llm.save(file_path=tmp_path / "nlpcloud.yaml")
|
||||
loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
|
||||
assert_llm_equality(llm, loaded_llm)
|
||||
|
||||
|
||||
def test_nlpcloud_api_key(monkeypatch: MonkeyPatch, capsys: CaptureFixture) -> None:
|
||||
"""Test that nlpcloud api key is a secret key."""
|
||||
# test initialization from init
|
||||
assert isinstance(NLPCloud(nlpcloud_api_key="1").nlpcloud_api_key, SecretStr)
|
||||
|
||||
monkeypatch.setenv("NLPCLOUD_API_KEY", "secret-api-key")
|
||||
llm = NLPCloud()
|
||||
assert isinstance(llm.nlpcloud_api_key, SecretStr)
|
||||
|
||||
assert cast(SecretStr, llm.nlpcloud_api_key).get_secret_value() == "secret-api-key"
|
||||
|
||||
print(llm.nlpcloud_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
Loading…
Reference in New Issue
Block a user