mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
community[patch]: Add linter to catch @root_validator (#24070)
- Add linter to prevent further usage of vanilla root validator - Udpate remaining root validators
This commit is contained in:
parent
9c6efadec3
commit
c4e149d4f1
@ -63,7 +63,7 @@ class FileManagementToolkit(BaseToolkit):
|
|||||||
selected_tools: Optional[List[str]] = None
|
selected_tools: Optional[List[str]] = None
|
||||||
"""If provided, only provide the selected tools. Defaults to all."""
|
"""If provided, only provide the selected tools. Defaults to all."""
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_tools(cls, values: dict) -> dict:
|
def validate_tools(cls, values: dict) -> dict:
|
||||||
selected_tools = values.get("selected_tools") or []
|
selected_tools = values.get("selected_tools") or []
|
||||||
for tool_name in selected_tools:
|
for tool_name in selected_tools:
|
||||||
|
@ -74,7 +74,7 @@ class PlayWrightBrowserToolkit(BaseToolkit):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
|
def validate_imports_and_browser_provided(cls, values: dict) -> dict:
|
||||||
"""Check that the arguments are valid."""
|
"""Check that the arguments are valid."""
|
||||||
lazy_import_playwright_browsers()
|
lazy_import_playwright_browsers()
|
||||||
|
@ -81,7 +81,7 @@ class DocugamiLoader(BaseLoader, BaseModel):
|
|||||||
include_project_metadata_in_doc_metadata: bool = True
|
include_project_metadata_in_doc_metadata: bool = True
|
||||||
"""Set to True if you want to include the project metadata in the doc metadata."""
|
"""Set to True if you want to include the project metadata in the doc metadata."""
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_local_or_remote(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate that either local file paths are given, or remote API docset ID.
|
"""Validate that either local file paths are given, or remote API docset ID.
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class DropboxLoader(BaseLoader, BaseModel):
|
|||||||
recursive: bool = False
|
recursive: bool = False
|
||||||
"""Flag to indicate whether to load files recursively from subfolders."""
|
"""Flag to indicate whether to load files recursively from subfolders."""
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate that either folder_path or file_paths is set, but not both."""
|
"""Validate that either folder_path or file_paths is set, but not both."""
|
||||||
if (
|
if (
|
||||||
|
@ -53,7 +53,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
|||||||
file_loader_kwargs: Dict["str", Any] = {}
|
file_loader_kwargs: Dict["str", Any] = {}
|
||||||
"""The file loader kwargs to use."""
|
"""The file loader kwargs to use."""
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate that either folder_id or document_ids is set, but not both."""
|
"""Validate that either folder_id or document_ids is set, but not both."""
|
||||||
if values.get("folder_id") and (
|
if values.get("folder_id") and (
|
||||||
|
@ -47,7 +47,7 @@ class GoogleApiClient:
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.creds = self._load_credentials()
|
self.creds = self._load_credentials()
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_channel_or_videoIds_is_set(
|
def validate_channel_or_videoIds_is_set(
|
||||||
cls, values: Dict[str, Any]
|
cls, values: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -388,7 +388,7 @@ class GoogleApiYoutubeLoader(BaseLoader):
|
|||||||
|
|
||||||
return build("youtube", "v3", credentials=creds)
|
return build("youtube", "v3", credentials=creds)
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_channel_or_videoIds_is_set(
|
def validate_channel_or_videoIds_is_set(
|
||||||
cls, values: Dict[str, Any]
|
cls, values: Dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
@ -54,8 +54,10 @@ class AscendEmbeddings(Embeddings, BaseModel):
|
|||||||
self.model.half()
|
self.model.half()
|
||||||
self.encode([f"warmup {i} times" for i in range(10)])
|
self.encode([f"warmup {i} times" for i in range(10)])
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
if "model_path" not in values:
|
||||||
|
raise ValueError("model_path is required")
|
||||||
if not os.access(values["model_path"], os.F_OK):
|
if not os.access(values["model_path"], os.F_OK):
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Unabled to find valid model path in [{values['model_path']}]"
|
f"Unabled to find valid model path in [{values['model_path']}]"
|
||||||
|
@ -65,7 +65,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
|
|||||||
)
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def check_voice_models_key_is_provider_name(cls, values: dict) -> dict:
|
def check_voice_models_key_is_provider_name(cls, values: dict) -> dict:
|
||||||
for key in values.get("voice_models", {}).keys():
|
for key in values.get("voice_models", {}).keys():
|
||||||
if key not in values.get("providers", []):
|
if key not in values.get("providers", []):
|
||||||
|
@ -38,7 +38,7 @@ class BaseBrowserTool(BaseTool):
|
|||||||
sync_browser: Optional["SyncBrowser"] = None
|
sync_browser: Optional["SyncBrowser"] = None
|
||||||
async_browser: Optional["AsyncBrowser"] = None
|
async_browser: Optional["AsyncBrowser"] = None
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_browser_provided(cls, values: dict) -> dict:
|
def validate_browser_provided(cls, values: dict) -> dict:
|
||||||
"""Check that the arguments are valid."""
|
"""Check that the arguments are valid."""
|
||||||
lazy_import_playwright_browsers()
|
lazy_import_playwright_browsers()
|
||||||
|
@ -35,7 +35,7 @@ class ExtractHyperlinksTool(BaseBrowserTool):
|
|||||||
description: str = "Extract all hyperlinks on the current webpage"
|
description: str = "Extract all hyperlinks on the current webpage"
|
||||||
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
|
args_schema: Type[BaseModel] = ExtractHyperlinksToolInput
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def check_bs_import(cls, values: dict) -> dict:
|
def check_bs_import(cls, values: dict) -> dict:
|
||||||
"""Check that the arguments are valid."""
|
"""Check that the arguments are valid."""
|
||||||
try:
|
try:
|
||||||
|
@ -22,7 +22,7 @@ class ExtractTextTool(BaseBrowserTool):
|
|||||||
description: str = "Extract all the text on the current webpage"
|
description: str = "Extract all the text on the current webpage"
|
||||||
args_schema: Type[BaseModel] = BaseModel
|
args_schema: Type[BaseModel] = BaseModel
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def check_acheck_bs_importrgs(cls, values: dict) -> dict:
|
def check_acheck_bs_importrgs(cls, values: dict) -> dict:
|
||||||
"""Check that the arguments are valid."""
|
"""Check that the arguments are valid."""
|
||||||
try:
|
try:
|
||||||
|
@ -21,7 +21,7 @@ class ShellInput(BaseModel):
|
|||||||
)
|
)
|
||||||
"""List of shell commands to run."""
|
"""List of shell commands to run."""
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def _validate_commands(cls, values: dict) -> dict:
|
def _validate_commands(cls, values: dict) -> dict:
|
||||||
"""Validate commands."""
|
"""Validate commands."""
|
||||||
# TODO: Add real validators
|
# TODO: Add real validators
|
||||||
|
@ -75,8 +75,9 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
|
from langchain_community.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
|
||||||
from langchain_community.utilities.zapier import ZapierNLAWrapper
|
from langchain_community.utilities.zapier import ZapierNLAWrapper
|
||||||
@ -105,7 +106,7 @@ class ZapierNLARunAction(BaseTool):
|
|||||||
name: str = ""
|
name: str = ""
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
@root_validator
|
@pre_init
|
||||||
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def set_name_description(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
zapier_description = values["zapier_description"]
|
zapier_description = values["zapier_description"]
|
||||||
params_schema = values["params_schema"]
|
params_schema = values["params_schema"]
|
||||||
|
@ -39,7 +39,7 @@ class SteamWebAPIWrapper(BaseModel):
|
|||||||
"""Return a list of operations."""
|
"""Return a list of operations."""
|
||||||
return self.operations
|
return self.operations
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: dict) -> dict:
|
def validate_environment(cls, values: dict) -> dict:
|
||||||
"""Validate api key and python package has been configured."""
|
"""Validate api key and python package has been configured."""
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ class YouSearchAPIWrapper(BaseModel):
|
|||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict:
|
def warn_if_set_fields_have_no_effect(cls, values: Dict) -> Dict:
|
||||||
if values["endpoint_type"] != "news":
|
if values["endpoint_type"] != "news":
|
||||||
news_api_fields = ("search_lang", "ui_lang", "spellcheck")
|
news_api_fields = ("search_lang", "ui_lang", "spellcheck")
|
||||||
@ -139,7 +139,7 @@ class YouSearchAPIWrapper(BaseModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict:
|
def warn_if_deprecated_endpoints_are_used(cls, values: Dict) -> Dict:
|
||||||
if values["endpoint_type"] == "snippets":
|
if values["endpoint_type"] == "snippets":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -14,7 +14,7 @@ fi
|
|||||||
repository_path="$1"
|
repository_path="$1"
|
||||||
|
|
||||||
# Search for lines matching the pattern within the specified repository
|
# Search for lines matching the pattern within the specified repository
|
||||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
result=$(git -C "$repository_path" grep -En '^import pydantic|^from pydantic')
|
||||||
|
|
||||||
# Check if any matching lines were found
|
# Check if any matching lines were found
|
||||||
if [ -n "$result" ]; then
|
if [ -n "$result" ]; then
|
||||||
@ -25,3 +25,20 @@ if [ -n "$result" ]; then
|
|||||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Forbid vanilla usage of @root_validator
|
||||||
|
# This prevents the code from using either @root_validator or @root_validator()
|
||||||
|
# Search for lines matching the pattern within the specified repository
|
||||||
|
result=$(git -C "$repository_path" grep -En '(@root_validator\s*$)|(@root_validator\(\))' -- '*.py')
|
||||||
|
|
||||||
|
# Check if any matching lines were found
|
||||||
|
if [ -n "$result" ]; then
|
||||||
|
echo "ERROR: The following lines need to be updated:"
|
||||||
|
echo
|
||||||
|
echo "$result"
|
||||||
|
echo
|
||||||
|
echo "Please replace @root_validator or @root_validator() with either:"
|
||||||
|
echo
|
||||||
|
echo "@root_validator(pre=True) or @root_validator(pre=False, skip_on_failure=True)"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
Loading…
Reference in New Issue
Block a user