community[patch]: Add linter to catch @root_validator (#24070)

- Add linter to prevent further usage of vanilla root validator
- Udpate remaining root validators
This commit is contained in:
Eugene Yurtsev 2024-07-10 10:51:03 -04:00 committed by GitHub
parent 9c6efadec3
commit c4e149d4f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 39 additions and 19 deletions

View File

@ -63,7 +63,7 @@ class FileManagementToolkit(BaseToolkit):
selected_tools: Optional[List[str]] = None selected_tools: Optional[List[str]] = None
"""If provided, only provide the selected tools. Defaults to all.""" """If provided, only provide the selected tools. Defaults to all."""
@root_validator @root_validator(pre=True)
def validate_tools(cls, values: dict) -> dict: def validate_tools(cls, values: dict) -> dict:
selected_tools = values.get("selected_tools") or [] selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools: for tool_name in selected_tools:

View File

@ -74,7 +74,7 @@ class PlayWrightBrowserToolkit(BaseToolkit):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@root_validator @root_validator(pre=True)
def validate_imports_and_browser_provided(cls, values: dict) -> dict: def validate_imports_and_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid.""" """Check that the arguments are valid."""
lazy_import_playwright_browsers() lazy_import_playwright_browsers()

View File

@ -81,7 +81,7 @@ class DocugamiLoader(BaseLoader, BaseModel):
include_project_metadata_in_doc_metadata: bool = True include_project_metadata_in_doc_metadata: bool = True
"""Set to True if you want to include the project metadata in the doc metadata.""" """Set to True if you want to include the project metadata in the doc metadata."""
@root_validator @root_validator(pre=True)
def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that either local file paths are given, or remote API docset ID. """Validate that either local file paths are given, or remote API docset ID.

View File

@ -33,7 +33,7 @@ class DropboxLoader(BaseLoader, BaseModel):
recursive: bool = False recursive: bool = False
"""Flag to indicate whether to load files recursively from subfolders.""" """Flag to indicate whether to load files recursively from subfolders."""
@root_validator @root_validator(pre=True)
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that either folder_path or file_paths is set, but not both.""" """Validate that either folder_path or file_paths is set, but not both."""
if ( if (

View File

@ -53,7 +53,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
file_loader_kwargs: Dict["str", Any] = {} file_loader_kwargs: Dict["str", Any] = {}
"""The file loader kwargs to use.""" """The file loader kwargs to use."""
@root_validator @root_validator(pre=True)
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that either folder_id or document_ids is set, but not both.""" """Validate that either folder_id or document_ids is set, but not both."""
if values.get("folder_id") and ( if values.get("folder_id") and (

View File

@ -47,7 +47,7 @@ class GoogleApiClient:
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.creds = self._load_credentials() self.creds = self._load_credentials()
@root_validator @root_validator(pre=True)
def validate_channel_or_videoIds_is_set( def validate_channel_or_videoIds_is_set(
cls, values: Dict[str, Any] cls, values: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -388,7 +388,7 @@ class GoogleApiYoutubeLoader(BaseLoader):
return build("youtube", "v3", credentials=creds) return build("youtube", "v3", credentials=creds)
@root_validator @root_validator(pre=True)
def validate_channel_or_videoIds_is_set( def validate_channel_or_videoIds_is_set(
cls, values: Dict[str, Any] cls, values: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:

View File

@ -54,8 +54,10 @@ class AscendEmbeddings(Embeddings, BaseModel):
self.model.half() self.model.half()
self.encode([f"warmup {i} times" for i in range(10)]) self.encode([f"warmup {i} times" for i in range(10)])
@root_validator @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
if "model_path" not in values:
raise ValueError("model_path is required")
if not os.access(values["model_path"], os.F_OK): if not os.access(values["model_path"], os.F_OK):
raise FileNotFoundError( raise FileNotFoundError(
f"Unabled to find valid model path in [{values['model_path']}]" f"Unabled to find valid model path in [{values['model_path']}]"

View File

@ -65,7 +65,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
) )
return v return v
@root_validator @root_validator(pre=True)
def check_voice_models_key_is_provider_name(cls, values: dict) -> dict: def check_voice_models_key_is_provider_name(cls, values: dict) -> dict:
for key in values.get("voice_models", {}).keys(): for key in values.get("voice_models", {}).keys():
if key not in values.get("providers", []): if key not in values.get("providers", []):

View File

@ -38,7 +38,7 @@ class BaseBrowserTool(BaseTool):
sync_browser: Optional["SyncBrowser"] = None sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None async_browser: Optional["AsyncBrowser"] = None
@root_validator @root_validator(pre=True)
def validate_browser_provided(cls, values: dict) -> dict: def validate_browser_provided(cls, values: dict) -> dict:
"""Check that the arguments are valid.""" """Check that the arguments are valid."""
lazy_import_playwright_browsers() lazy_import_playwright_browsers()

View File

@ -35,7 +35,7 @@ class ExtractHyperlinksTool(BaseBrowserTool):
description: str = "Extract all hyperlinks on the current webpage" description: str = "Extract all hyperlinks on the current webpage"
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
@root_validator @root_validator(pre=True)
def check_bs_import(cls, values: dict) -> dict: def check_bs_import(cls, values: dict) -> dict:
"""Check that the arguments are valid.""" """Check that the arguments are valid."""
try: try:

View File

@ -22,7 +22,7 @@ class ExtractTextTool(BaseBrowserTool):
description: str = "Extract all the text on the current webpage" description: str = "Extract all the text on the current webpage"
args_schema: Type[BaseModel] = BaseModel args_schema: Type[BaseModel] = BaseModel
@root_validator @root_validator(pre=True)
def check_acheck_bs_importrgs(cls, values: dict) -> dict: def check_acheck_bs_importrgs(cls, values: dict) -> dict:
"""Check that the arguments are valid.""" """Check that the arguments are valid."""
try: try:

View File

@ -21,7 +21,7 @@ class ShellInput(BaseModel):
) )
"""List of shell commands to run.""" """List of shell commands to run."""
@root_validator @root_validator(pre=True)
def _validate_commands(cls, values: dict) -> dict: def _validate_commands(cls, values: dict) -> dict:
"""Validate commands.""" """Validate commands."""
# TODO: Add real validators # TODO: Add real validators

View File

@ -75,8 +75,9 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import pre_init
from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
from langchain_community.utilities.zapier import ZapierNLAWrapper from langchain_community.utilities.zapier import ZapierNLAWrapper
@ -105,7 +106,7 @@ class ZapierNLARunAction(BaseTool):
name: str = "" name: str = ""
description: str = "" description: str = ""
@root_validator @pre_init
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]: def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
zapier_description = values["zapier_description"] zapier_description = values["zapier_description"]
params_schema = values["params_schema"] params_schema = values["params_schema"]

View File

@ -39,7 +39,7 @@ class SteamWebAPIWrapper(BaseModel):
"""Return a list of operations.""" """Return a list of operations."""
return self.operations return self.operations
@root_validator @root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict: def validate_environment(cls, values: dict) -> dict:
"""Validate api key and python package has been configured.""" """Validate api key and python package has been configured."""

View File

@ -114,7 +114,7 @@ class YouSearchAPIWrapper(BaseModel):
return values return values
@root_validator @root_validator(pre=False, skip_on_failure=True)
def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict: def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict:
if values["endpoint_type"] != "news": if values["endpoint_type"] != "news":
news_api_fields = ("search_lang", "ui_lang", "spellcheck") news_api_fields = ("search_lang", "ui_lang", "spellcheck")
@ -139,7 +139,7 @@ class YouSearchAPIWrapper(BaseModel):
) )
return values return values
@root_validator @root_validator(pre=False, skip_on_failure=True)
def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict: def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict:
if values["endpoint_type"] == "snippets": if values["endpoint_type"] == "snippets":
warnings.warn( warnings.warn(

View File

@ -14,7 +14,7 @@ fi
repository_path="$1" repository_path="$1"
# Search for lines matching the pattern within the specified repository # Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') result=$(git -C "$repository_path" grep -En '^import pydantic|^from pydantic')
# Check if any matching lines were found # Check if any matching lines were found
if [ -n "$result" ]; then if [ -n "$result" ]; then
@ -25,3 +25,20 @@ if [ -n "$result" ]; then
echo "with 'from langchain_core.pydantic_v1 import BaseModel'" echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1 exit 1
fi fi
# Forbid vanilla usage of @root_validator
# This prevents the code from using either @root_validator or @root_validator()
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\))' -- '*.py')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo
echo "$result"
echo
echo "Please replace @root_validator or @root_validator() with either:"
echo
echo "@root_validator(pre=True) or @root_validator(pre=False, skip_on_failure=True)"
exit 1
fi