mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
Refactor: use SecretStr for palm chat-model (#15100)
This commit is contained in:
parent
b6952d41e5
commit
32e6e9de13
@ -20,8 +20,8 @@ from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
@ -233,7 +233,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
client: Any #: :meta private:
|
||||
model_name: str = "models/chat-bison-001"
|
||||
"""Model name to use."""
|
||||
google_api_key: Optional[str] = None
|
||||
google_api_key: Optional[SecretStr] = None
|
||||
temperature: Optional[float] = None
|
||||
"""Run inference with this temperature. Must by in the closed
|
||||
interval [0.0, 1.0]."""
|
||||
@ -263,13 +263,13 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, top_p, and top_k."""
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
google_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
|
||||
)
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
genai.configure(api_key=google_api_key.get_secret_value())
|
||||
except ImportError:
|
||||
raise ChatGooglePalmError(
|
||||
"Could not import google.generativeai python package. "
|
||||
|
Loading…
Reference in New Issue
Block a user