mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
parent
85bb3a418c
commit
aff1dba252
@ -1,10 +1,10 @@
|
|||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
class NLPCloud(LLM):
|
class NLPCloud(LLM):
|
||||||
@ -50,7 +50,7 @@ class NLPCloud(LLM):
|
|||||||
num_return_sequences: int = 1
|
num_return_sequences: int = 1
|
||||||
"""How many completions to generate for each prompt."""
|
"""How many completions to generate for each prompt."""
|
||||||
|
|
||||||
nlpcloud_api_key: Optional[str] = None
|
nlpcloud_api_key: Optional[SecretStr] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -60,15 +60,15 @@ class NLPCloud(LLM):
|
|||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
nlpcloud_api_key = get_from_dict_or_env(
|
values["nlpcloud_api_key"] = convert_to_secret_str(
|
||||||
values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
|
get_from_dict_or_env(values, "nlpcloud_api_key", "NLPCLOUD_API_KEY")
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
import nlpcloud
|
import nlpcloud
|
||||||
|
|
||||||
values["client"] = nlpcloud.Client(
|
values["client"] = nlpcloud.Client(
|
||||||
values["model_name"],
|
values["model_name"],
|
||||||
nlpcloud_api_key,
|
values["nlpcloud_api_key"].get_secret_value(),
|
||||||
gpu=values["gpu"],
|
gpu=values["gpu"],
|
||||||
lang=values["lang"],
|
lang=values["lang"],
|
||||||
)
|
)
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
"""Test NLPCloud API wrapper."""
|
"""Test NLPCloud API wrapper."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from pytest import CaptureFixture, MonkeyPatch
|
||||||
|
|
||||||
from langchain.llms.loading import load_llm
|
from langchain.llms.loading import load_llm
|
||||||
from langchain.llms.nlpcloud import NLPCloud
|
from langchain.llms.nlpcloud import NLPCloud
|
||||||
|
from langchain.pydantic_v1 import SecretStr
|
||||||
from tests.integration_tests.llms.utils import assert_llm_equality
|
from tests.integration_tests.llms.utils import assert_llm_equality
|
||||||
|
|
||||||
|
|
||||||
@ -20,3 +24,20 @@ def test_saving_loading_llm(tmp_path: Path) -> None:
|
|||||||
llm.save(file_path=tmp_path / "nlpcloud.yaml")
|
llm.save(file_path=tmp_path / "nlpcloud.yaml")
|
||||||
loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
|
loaded_llm = load_llm(tmp_path / "nlpcloud.yaml")
|
||||||
assert_llm_equality(llm, loaded_llm)
|
assert_llm_equality(llm, loaded_llm)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nlpcloud_api_key(monkeypatch: MonkeyPatch, capsys: CaptureFixture) -> None:
|
||||||
|
"""Test that nlpcloud api key is a secret key."""
|
||||||
|
# test initialization from init
|
||||||
|
assert isinstance(NLPCloud(nlpcloud_api_key="1").nlpcloud_api_key, SecretStr)
|
||||||
|
|
||||||
|
monkeypatch.setenv("NLPCLOUD_API_KEY", "secret-api-key")
|
||||||
|
llm = NLPCloud()
|
||||||
|
assert isinstance(llm.nlpcloud_api_key, SecretStr)
|
||||||
|
|
||||||
|
assert cast(SecretStr, llm.nlpcloud_api_key).get_secret_value() == "secret-api-key"
|
||||||
|
|
||||||
|
print(llm.nlpcloud_api_key, end="")
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
|
||||||
|
assert captured.out == "**********"
|
||||||
|
Loading…
Reference in New Issue
Block a user