community[patch]: Update root_validators to use pre=True or pre=False (#23731)

Update root_validators in preparation for pydantic 2 migration.
This commit is contained in:
Eugene Yurtsev 2024-07-01 16:10:15 -04:00 committed by GitHub
parent 6019147b66
commit 5d2262af34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 13 additions and 10 deletions

View File

@ -209,7 +209,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
as_agent: bool = False
"""Use as a LangChain agent, compatible with the AgentExecutor."""
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_async_client(cls, values: dict) -> dict:
if values["async_client"] is None:
import openai

View File

@ -59,7 +59,7 @@ class LLMRequestsChain(Chain):
"""
return [self.output_key]
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:

View File

@ -101,7 +101,7 @@ class ChatAnyscale(ChatOpenAI):
return {model["id"] for model in models_response.json()["data"]}
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
values["anyscale_api_key"] = convert_to_secret_str(

View File

@ -116,7 +116,7 @@ class ChatCoze(BaseChatModel):
allow_population_by_field_name = True
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["coze_api_base"] = get_from_dict_or_env(
values,

View File

@ -75,7 +75,7 @@ class ChatDappierAI(BaseChatModel):
extra = Extra.forbid
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["dappier_api_key"] = convert_to_secret_str(

View File

@ -278,8 +278,8 @@ class ChatDeepInfra(BaseChatModel):
return await _completion_with_retry(**kwargs)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@root_validator(pre=True)
def init_defaults(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
# For compatibility with LiteLLM
api_key = get_from_dict_or_env(
@ -294,7 +294,10 @@ class ChatDeepInfra(BaseChatModel):
"DEEPINFRA_API_TOKEN",
default=api_key,
)
return values
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

View File

@ -108,7 +108,7 @@ class ErnieBotChat(BaseChatModel):
_lock = threading.Lock()
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["ernie_api_base"] = get_from_dict_or_env(
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"

View File

@ -112,7 +112,7 @@ class ChatFireworks(BaseChatModel):
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment."""
try:

View File

@ -76,7 +76,7 @@ class ChatHuggingFace(BaseChatModel):
else self.tokenizer
)
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_llm(cls, values: dict) -> dict:
if not isinstance(
values["llm"],