Refactor: use SecretStr for palm chat-model (#15100)

This commit is contained in:
chyroc 2024-01-02 07:21:41 +08:00 committed by GitHub
parent b6952d41e5
commit 32e6e9de13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,8 +20,8 @@ from langchain_core.outputs import (
ChatGeneration, ChatGeneration,
ChatResult, ChatResult,
) )
from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
retry, retry,
@ -233,7 +233,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
client: Any #: :meta private: client: Any #: :meta private:
model_name: str = "models/chat-bison-001" model_name: str = "models/chat-bison-001"
"""Model name to use.""" """Model name to use."""
google_api_key: Optional[str] = None google_api_key: Optional[SecretStr] = None
temperature: Optional[float] = None temperature: Optional[float] = None
"""Run inference with this temperature. Must by in the closed """Run inference with this temperature. Must by in the closed
interval [0.0, 1.0].""" interval [0.0, 1.0]."""
@ -263,13 +263,13 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k.""" """Validate api key, python package exists, temperature, top_p, and top_k."""
google_api_key = get_from_dict_or_env( google_api_key = convert_to_secret_str(
values, "google_api_key", "GOOGLE_API_KEY" get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
) )
try: try:
import google.generativeai as genai import google.generativeai as genai
genai.configure(api_key=google_api_key) genai.configure(api_key=google_api_key.get_secret_value())
except ImportError: except ImportError:
raise ChatGooglePalmError( raise ChatGooglePalmError(
"Could not import google.generativeai python package. " "Could not import google.generativeai python package. "