diff --git a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py index 866bd59c0f5..bfcb77ba2c0 100644 --- a/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/file_management/toolkit.py @@ -63,7 +63,7 @@ class FileManagementToolkit(BaseToolkit): selected_tools: Optional[List[str]] = None """If provided, only provide the selected tools. Defaults to all.""" - @root_validator + @root_validator(pre=True) def validate_tools(cls, values: dict) -> dict: selected_tools = values.get("selected_tools") or [] for tool_name in selected_tools: diff --git a/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py b/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py index 40b4c411491..28ccf1aabb1 100644 --- a/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/playwright/toolkit.py @@ -74,7 +74,7 @@ class PlayWrightBrowserToolkit(BaseToolkit): extra = Extra.forbid arbitrary_types_allowed = True - @root_validator + @root_validator(pre=True) def validate_imports_and_browser_provided(cls, values: dict) -> dict: """Check that the arguments are valid.""" lazy_import_playwright_browsers() diff --git a/libs/community/langchain_community/document_loaders/docugami.py b/libs/community/langchain_community/document_loaders/docugami.py index fbbda405f90..7082b219652 100644 --- a/libs/community/langchain_community/document_loaders/docugami.py +++ b/libs/community/langchain_community/document_loaders/docugami.py @@ -81,7 +81,7 @@ class DocugamiLoader(BaseLoader, BaseModel): include_project_metadata_in_doc_metadata: bool = True """Set to True if you want to include the project metadata in the doc metadata.""" - @root_validator + @root_validator(pre=True) def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate that either local file paths are given, or remote API docset ID. diff --git a/libs/community/langchain_community/document_loaders/dropbox.py b/libs/community/langchain_community/document_loaders/dropbox.py index fb76043c138..f2a7b603cef 100644 --- a/libs/community/langchain_community/document_loaders/dropbox.py +++ b/libs/community/langchain_community/document_loaders/dropbox.py @@ -33,7 +33,7 @@ class DropboxLoader(BaseLoader, BaseModel): recursive: bool = False """Flag to indicate whether to load files recursively from subfolders.""" - @root_validator + @root_validator(pre=True) def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate that either folder_path or file_paths is set, but not both.""" if ( diff --git a/libs/community/langchain_community/document_loaders/googledrive.py b/libs/community/langchain_community/document_loaders/googledrive.py index 4e51b71cc63..b7a072059e6 100644 --- a/libs/community/langchain_community/document_loaders/googledrive.py +++ b/libs/community/langchain_community/document_loaders/googledrive.py @@ -53,7 +53,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel): file_loader_kwargs: Dict["str", Any] = {} """The file loader kwargs to use.""" - @root_validator + @root_validator(pre=True) def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate that either folder_id or document_ids is set, but not both.""" if values.get("folder_id") and ( diff --git a/libs/community/langchain_community/document_loaders/youtube.py b/libs/community/langchain_community/document_loaders/youtube.py index 3c16a0706a4..ccab3f7227e 100644 --- a/libs/community/langchain_community/document_loaders/youtube.py +++ b/libs/community/langchain_community/document_loaders/youtube.py @@ -47,7 +47,7 @@ class GoogleApiClient: def __post_init__(self) -> None: self.creds = self._load_credentials() - @root_validator + @root_validator(pre=True) def validate_channel_or_videoIds_is_set( cls, values: Dict[str, Any] ) -> Dict[str, Any]: @@ -388,7 +388,7 @@ class GoogleApiYoutubeLoader(BaseLoader): return build("youtube", "v3", credentials=creds) - @root_validator + @root_validator(pre=True) def validate_channel_or_videoIds_is_set( cls, values: Dict[str, Any] ) -> Dict[str, Any]: diff --git a/libs/community/langchain_community/embeddings/ascend.py b/libs/community/langchain_community/embeddings/ascend.py index 4e71635663f..7512599bc53 100644 --- a/libs/community/langchain_community/embeddings/ascend.py +++ b/libs/community/langchain_community/embeddings/ascend.py @@ -54,8 +54,10 @@ class AscendEmbeddings(Embeddings, BaseModel): self.model.half() self.encode([f"warmup {i} times" for i in range(10)]) - @root_validator + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: + if "model_path" not in values: + raise ValueError("model_path is required") if not os.access(values["model_path"], os.F_OK): raise FileNotFoundError( f"Unabled to find valid model path in [{values['model_path']}]" diff --git a/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py b/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py index fad670e0f3f..421d06f5b00 100644 --- a/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py +++ b/libs/community/langchain_community/tools/edenai/audio_text_to_speech.py @@ -65,7 +65,7 @@ class EdenAiTextToSpeechTool(EdenaiTool): ) return v - @root_validator + @root_validator(pre=True) def check_voice_models_key_is_provider_name(cls, values: dict) -> dict: for key in values.get("voice_models", {}).keys(): if key not in values.get("providers", []): diff --git a/libs/community/langchain_community/tools/playwright/base.py b/libs/community/langchain_community/tools/playwright/base.py index 5fab5a29026..f06b964e4d2 100644 --- a/libs/community/langchain_community/tools/playwright/base.py +++ b/libs/community/langchain_community/tools/playwright/base.py @@ -38,7 +38,7 @@ class BaseBrowserTool(BaseTool): sync_browser: Optional["SyncBrowser"] = None async_browser: Optional["AsyncBrowser"] = None - @root_validator + @root_validator(pre=True) def validate_browser_provided(cls, values: dict) -> dict: """Check that the arguments are valid.""" lazy_import_playwright_browsers() diff --git a/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py b/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py index 9c6f64911cc..3cac3c496b3 100644 --- a/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py +++ b/libs/community/langchain_community/tools/playwright/extract_hyperlinks.py @@ -35,7 +35,7 @@ class ExtractHyperlinksTool(BaseBrowserTool): description: str = "Extract all hyperlinks on the current webpage" args_schema: Type[BaseModel] = ExtractHyperlinksToolInput - @root_validator + @root_validator(pre=True) def check_bs_import(cls, values: dict) -> dict: """Check that the arguments are valid.""" try: diff --git a/libs/community/langchain_community/tools/playwright/extract_text.py b/libs/community/langchain_community/tools/playwright/extract_text.py index 4f01ea925d0..0cf112b0fb2 100644 --- a/libs/community/langchain_community/tools/playwright/extract_text.py +++ b/libs/community/langchain_community/tools/playwright/extract_text.py @@ -22,7 +22,7 @@ class ExtractTextTool(BaseBrowserTool): description: str = "Extract all the text on the current webpage" args_schema: Type[BaseModel] = BaseModel - @root_validator + @root_validator(pre=True) def check_acheck_bs_importrgs(cls, values: dict) -> dict: """Check that the arguments are valid.""" try: diff --git a/libs/community/langchain_community/tools/shell/tool.py b/libs/community/langchain_community/tools/shell/tool.py index cb3e941ec63..cbc77ea16ff 100644 --- a/libs/community/langchain_community/tools/shell/tool.py +++ b/libs/community/langchain_community/tools/shell/tool.py @@ -21,7 +21,7 @@ class ShellInput(BaseModel): ) """List of shell commands to run.""" - @root_validator + @root_validator(pre=True) def _validate_commands(cls, values: dict) -> dict: """Validate commands.""" # TODO: Add real validators diff --git a/libs/community/langchain_community/tools/zapier/tool.py b/libs/community/langchain_community/tools/zapier/tool.py index e5bb3c36599..7cc97132a7c 100644 --- a/libs/community/langchain_community/tools/zapier/tool.py +++ b/libs/community/langchain_community/tools/zapier/tool.py @@ -75,8 +75,9 @@ from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field from langchain_core.tools import BaseTool +from langchain_core.utils import pre_init from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT from langchain_community.utilities.zapier import ZapierNLAWrapper @@ -105,7 +106,7 @@ class ZapierNLARunAction(BaseTool): name: str = "" description: str = "" - @root_validator + @pre_init def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]: zapier_description = values["zapier_description"] params_schema = values["params_schema"] diff --git a/libs/community/langchain_community/utilities/steam.py b/libs/community/langchain_community/utilities/steam.py index 778c3c6870e..1110382ae60 100644 --- a/libs/community/langchain_community/utilities/steam.py +++ b/libs/community/langchain_community/utilities/steam.py @@ -39,7 +39,7 @@ class SteamWebAPIWrapper(BaseModel): """Return a list of operations.""" return self.operations - @root_validator + @root_validator(pre=True) def validate_environment(cls, values: dict) -> dict: """Validate api key and python package has been configured.""" diff --git a/libs/community/langchain_community/utilities/you.py b/libs/community/langchain_community/utilities/you.py index 98cca924727..1cd17bdc556 100644 --- a/libs/community/langchain_community/utilities/you.py +++ b/libs/community/langchain_community/utilities/you.py @@ -114,7 +114,7 @@ class YouSearchAPIWrapper(BaseModel): return values - @root_validator + @root_validator(pre=False, skip_on_failure=True) def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict: if values["endpoint_type"] != "news": news_api_fields = ("search_lang", "ui_lang", "spellcheck") @@ -139,7 +139,7 @@ class YouSearchAPIWrapper(BaseModel): ) return values - @root_validator + @root_validator(pre=False, skip_on_failure=True) def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict: if values["endpoint_type"] == "snippets": warnings.warn( diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh index 06b5bb81ae2..518338d377d 100755 --- a/libs/community/scripts/check_pydantic.sh +++ b/libs/community/scripts/check_pydantic.sh @@ -14,7 +14,7 @@ fi repository_path="$1" # Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') +result=$(git -C "$repository_path" grep -En '^import pydantic|^from pydantic') # Check if any matching lines were found if [ -n "$result" ]; then @@ -25,3 +25,20 @@ if [ -n "$result" ]; then echo "with 'from langchain_core.pydantic_v1 import BaseModel'" exit 1 fi + +# Forbid vanilla usage of @root_validator +# This prevents the code from using either @root_validator or @root_validator() +# Search for lines matching the pattern within the specified repository +result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\))' -- '*.py') + +# Check if any matching lines were found +if [ -n "$result" ]; then + echo "ERROR: The following lines need to be updated:" + echo + echo "$result" + echo + echo "Please replace @root_validator or @root_validator() with either:" + echo + echo "@root_validator(pre=True) or @root_validator(pre=False, skip_on_failure=True)" + exit 1 +fi