Compare commits

...

1 Commits

Author SHA1 Message Date
Eugene Yurtsev
f2a1c2726c Update @root_validators 2024-07-01 16:27:53 -04:00
5 changed files with 9 additions and 6 deletions

View File

@@ -135,7 +135,7 @@ class ChatSnowflakeCortex(BaseChatModel):
)
return values
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
try:
from snowflake.snowpark import Session

View File

@@ -37,7 +37,7 @@ class SolarChat(SolarCommon, ChatOpenAI):
arbitrary_types_allowed = True
extra = "ignore"
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the environment is set up correctly."""
values["solar_api_key"] = get_from_dict_or_env(

View File

@@ -431,7 +431,7 @@ class ChatTongyi(BaseChatModel):
"""Return type of llm."""
return "tongyi"
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["dashscope_api_key"] = convert_to_secret_str(

View File

@@ -203,6 +203,9 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage:
return question
_DEFAULT_MODEL_NAME = "chat-bison"
@deprecated(
since="0.0.12",
removal="0.3.0",
@@ -211,7 +214,7 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage:
class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""
model_name: str = "chat-bison"
model_name: str = _DEFAULT_MODEL_NAME
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
@@ -227,7 +230,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
is_gemini = is_gemini_model(values["model_name"])
is_gemini = is_gemini_model(values.get("model_name", _DEFAULT_MODEL_NAME))
cls._try_init_vertexai(values)
try:
from vertexai.language_models import ChatModel, CodeChatModel

View File

@@ -165,7 +165,7 @@ class ChatYuan2(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["yuan2_api_key"] = get_from_dict_or_env(