From 5d2262af341cbc97a454a47d8518632176dccb61 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 1 Jul 2024 16:10:15 -0400 Subject: [PATCH] community[patch]: Update root_validators to use pre=True or pre=False (#23731) Update root_validators in preparation for pydantic 2 migration. --- .../langchain_community/agents/openai_assistant/base.py | 2 +- libs/community/langchain_community/chains/llm_requests.py | 2 +- libs/community/langchain_community/chat_models/anyscale.py | 2 +- libs/community/langchain_community/chat_models/coze.py | 2 +- libs/community/langchain_community/chat_models/dappier.py | 2 +- .../community/langchain_community/chat_models/deepinfra.py | 7 +++++-- libs/community/langchain_community/chat_models/ernie.py | 2 +- .../community/langchain_community/chat_models/fireworks.py | 2 +- .../langchain_community/chat_models/huggingface.py | 2 +- 9 files changed, 13 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/agents/openai_assistant/base.py b/libs/community/langchain_community/agents/openai_assistant/base.py index 09b6d615b9c..6fb86486ca7 100644 --- a/libs/community/langchain_community/agents/openai_assistant/base.py +++ b/libs/community/langchain_community/agents/openai_assistant/base.py @@ -209,7 +209,7 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable): as_agent: bool = False """Use as a LangChain agent, compatible with the AgentExecutor.""" - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_async_client(cls, values: dict) -> dict: if values["async_client"] is None: import openai diff --git a/libs/community/langchain_community/chains/llm_requests.py b/libs/community/langchain_community/chains/llm_requests.py index 304a74fc09c..c71dcbafe85 100644 --- a/libs/community/langchain_community/chains/llm_requests.py +++ b/libs/community/langchain_community/chains/llm_requests.py @@ -59,7 +59,7 @@ class LLMRequestsChain(Chain): """ return [self.output_key] - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" try: diff --git a/libs/community/langchain_community/chat_models/anyscale.py b/libs/community/langchain_community/chat_models/anyscale.py index 2acdb49b5b8..df89dc4fb05 100644 --- a/libs/community/langchain_community/chat_models/anyscale.py +++ b/libs/community/langchain_community/chat_models/anyscale.py @@ -101,7 +101,7 @@ class ChatAnyscale(ChatOpenAI): return {model["id"] for model in models_response.json()["data"]} - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" values["anyscale_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/coze.py b/libs/community/langchain_community/chat_models/coze.py index 733719be135..fd330292586 100644 --- a/libs/community/langchain_community/chat_models/coze.py +++ b/libs/community/langchain_community/chat_models/coze.py @@ -116,7 +116,7 @@ class ChatCoze(BaseChatModel): allow_population_by_field_name = True - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: values["coze_api_base"] = get_from_dict_or_env( values, diff --git a/libs/community/langchain_community/chat_models/dappier.py b/libs/community/langchain_community/chat_models/dappier.py index 3468feebd7c..e78801914ed 100644 --- a/libs/community/langchain_community/chat_models/dappier.py +++ b/libs/community/langchain_community/chat_models/dappier.py @@ -75,7 +75,7 @@ class ChatDappierAI(BaseChatModel): extra = Extra.forbid - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" values["dappier_api_key"] = convert_to_secret_str( diff --git a/libs/community/langchain_community/chat_models/deepinfra.py b/libs/community/langchain_community/chat_models/deepinfra.py index 32be0867a0a..9bad6b251bf 100644 --- a/libs/community/langchain_community/chat_models/deepinfra.py +++ b/libs/community/langchain_community/chat_models/deepinfra.py @@ -278,8 +278,8 @@ class ChatDeepInfra(BaseChatModel): return await _completion_with_retry(**kwargs) - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: + @root_validator(pre=True) + def init_defaults(cls, values: Dict) -> Dict: """Validate api key, python package exists, temperature, top_p, and top_k.""" # For compatibility with LiteLLM api_key = get_from_dict_or_env( @@ -294,7 +294,10 @@ class ChatDeepInfra(BaseChatModel): "DEEPINFRA_API_TOKEN", default=api_key, ) + return values + @root_validator(pre=False, skip_on_failure=True) + def validate_environment(cls, values: Dict) -> Dict: if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: raise ValueError("temperature must be in the range [0.0, 1.0]") diff --git a/libs/community/langchain_community/chat_models/ernie.py b/libs/community/langchain_community/chat_models/ernie.py index 8766038e3b1..b8681fb867b 100644 --- a/libs/community/langchain_community/chat_models/ernie.py +++ b/libs/community/langchain_community/chat_models/ernie.py @@ -108,7 +108,7 @@ class ErnieBotChat(BaseChatModel): _lock = threading.Lock() - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: values["ernie_api_base"] = get_from_dict_or_env( values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com" diff --git a/libs/community/langchain_community/chat_models/fireworks.py b/libs/community/langchain_community/chat_models/fireworks.py index 2e434f1ae63..3147d469fcc 100644 --- a/libs/community/langchain_community/chat_models/fireworks.py +++ b/libs/community/langchain_community/chat_models/fireworks.py @@ -112,7 +112,7 @@ class ChatFireworks(BaseChatModel): """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "fireworks"] - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key in environment.""" try: diff --git a/libs/community/langchain_community/chat_models/huggingface.py b/libs/community/langchain_community/chat_models/huggingface.py index 68aaabaddd5..eb22b9ea4b7 100644 --- a/libs/community/langchain_community/chat_models/huggingface.py +++ b/libs/community/langchain_community/chat_models/huggingface.py @@ -76,7 +76,7 @@ class ChatHuggingFace(BaseChatModel): else self.tokenizer ) - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_llm(cls, values: dict) -> dict: if not isinstance( values["llm"],