diff --git a/libs/community/langchain_community/document_loaders/github.py b/libs/community/langchain_community/document_loaders/github.py index a92d854b7ea..94cbe0553d0 100644 --- a/libs/community/langchain_community/document_loaders/github.py +++ b/libs/community/langchain_community/document_loaders/github.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union import requests from langchain_core.documents import Document from langchain_core.utils import get_from_dict_or_env -from pydantic import BaseModel, model_validator, validator +from pydantic import BaseModel, field_validator, model_validator from langchain_community.document_loaders.base import BaseLoader @@ -73,7 +73,8 @@ class GitHubIssuesLoader(BaseGitHubLoader): """Number of items per page. Defaults to 30 in the GitHub API.""" - @validator("since", allow_reuse=True) + @field_validator("since") + @classmethod def validate_since(cls, v: Optional[str]) -> Optional[str]: if v: try: diff --git a/libs/community/langchain_community/llms/databricks.py b/libs/community/langchain_community/llms/databricks.py index 6d774d38ca2..9f535d7373d 100644 --- a/libs/community/langchain_community/llms/databricks.py +++ b/libs/community/langchain_community/llms/databricks.py @@ -13,7 +13,6 @@ from pydantic import ( Field, PrivateAttr, model_validator, - validator, ) __all__ = ["Databricks"] @@ -414,18 +413,21 @@ class Databricks(LLM): params["max_tokens"] = self.max_tokens return params - @validator("cluster_id", always=True) - def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]: - if v and values["endpoint_name"]: + @model_validator(mode="before") + @classmethod + def set_cluster_id(cls, values: Dict[str, Any]) -> dict: + cluster_id = values.get("cluster_id") + endpoint_name = values.get("endpoint_name") + if cluster_id and endpoint_name: raise ValueError("Cannot set both endpoint_name and cluster_id.") - elif values["endpoint_name"]: - return None - elif v: - return v + elif endpoint_name: + values["cluster_id"] = None + elif cluster_id: + pass else: try: - if v := get_repl_context().clusterId: - return v + if context_cluster_id := get_repl_context().clusterId: + values["cluster_id"] = context_cluster_id raise ValueError("Context doesn't contain clusterId.") except Exception as e: raise ValueError( @@ -434,27 +436,28 @@ class Databricks(LLM): f" error: {e}" ) - @validator("cluster_driver_port", always=True) - def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]: - if v and values["endpoint_name"]: + cluster_driver_port = values.get("cluster_driver_port") + if cluster_driver_port and endpoint_name: raise ValueError("Cannot set both endpoint_name and cluster_driver_port.") - elif values["endpoint_name"]: - return None - elif v is None: + elif endpoint_name: + values["cluster_driver_port"] = None + elif cluster_driver_port is None: raise ValueError( "Must set cluster_driver_port to connect to a cluster driver." ) - elif int(v) <= 0: - raise ValueError(f"Invalid cluster_driver_port: {v}") + elif int(cluster_driver_port) <= 0: + raise ValueError(f"Invalid cluster_driver_port: {cluster_driver_port}") else: - return v + pass - @validator("model_kwargs", always=True) - def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: - if v: - assert "prompt" not in v, "model_kwargs must not contain key 'prompt'" - assert "stop" not in v, "model_kwargs must not contain key 'stop'" - return v + if model_kwargs := values.get("model_kwargs"): + assert ( + "prompt" not in model_kwargs + ), "model_kwargs must not contain key 'prompt'" + assert ( + "stop" not in model_kwargs + ), "model_kwargs must not contain key 'stop'" + return values def __init__(self, **data: Any): if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]): diff --git a/libs/community/langchain_community/retrievers/kendra.py b/libs/community/langchain_community/retrievers/kendra.py index 7c824b0eacf..fe36e85dbe1 100644 --- a/libs/community/langchain_community/retrievers/kendra.py +++ b/libs/community/langchain_community/retrievers/kendra.py @@ -385,6 +385,10 @@ class AmazonKendraRetriever(BaseRetriever): @model_validator(mode="before") @classmethod def create_client(cls, values: Dict[str, Any]) -> Any: + top_k = values.get("top_k") + if top_k is not None and top_k < 0: + raise ValueError(f"top_k ({top_k}) cannot be negative.") + if values.get("client") is not None: return values diff --git a/libs/community/langchain_community/tools/playwright/navigate.py b/libs/community/langchain_community/tools/playwright/navigate.py index 5c0de271ce9..2bfe2be4fd7 100644 --- a/libs/community/langchain_community/tools/playwright/navigate.py +++ b/libs/community/langchain_community/tools/playwright/navigate.py @@ -7,7 +7,7 @@ from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, model_validator from langchain_community.tools.playwright.base import BaseBrowserTool from langchain_community.tools.playwright.utils import ( @@ -21,13 +21,15 @@ class NavigateToolInput(BaseModel): url: str = Field(..., description="url to navigate to") - @validator("url") - def validate_url_scheme(cls, url: str) -> str: + @model_validator(mode="before") + @classmethod + def validate_url_scheme(cls, values: dict) -> dict: """Check that the URL scheme is valid.""" + url = values.get("url") parsed_url = urlparse(url) if parsed_url.scheme not in ("http", "https"): raise ValueError("URL scheme must be 'http' or 'https'") - return url + return values class NavigateTool(BaseBrowserTool): diff --git a/libs/community/langchain_community/tools/powerbi/tool.py b/libs/community/langchain_community/tools/powerbi/tool.py index 9d55f431c2f..59794294a5d 100644 --- a/libs/community/langchain_community/tools/powerbi/tool.py +++ b/libs/community/langchain_community/tools/powerbi/tool.py @@ -9,7 +9,7 @@ from langchain_core.callbacks import ( CallbackManagerForToolRun, ) from langchain_core.tools import BaseTool -from pydantic import ConfigDict, Field, validator +from pydantic import ConfigDict, Field, model_validator from langchain_community.chat_models.openai import _import_tiktoken from langchain_community.tools.powerbi.prompt import ( @@ -43,18 +43,20 @@ class QueryPowerBITool(BaseTool): arbitrary_types_allowed=True, ) - @validator("llm_chain") + @model_validator(mode="before") + @classmethod def validate_llm_chain_input_variables( # pylint: disable=E0213 - cls, llm_chain: Any - ) -> Any: + cls, values: dict + ) -> dict: """Make sure the LLM chain has the correct input variables.""" + llm_chain = values["llm_chain"] for var in llm_chain.prompt.input_variables: if var not in ["tool_input", "tables", "schemas", "examples"]: raise ValueError( "LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301 llm_chain.prompt.input_variables, ) - return llm_chain + return values def _check_cache(self, tool_input: str) -> Optional[str]: """Check if the input is present in the cache. diff --git a/libs/community/langchain_community/utilities/powerbi.py b/libs/community/langchain_community/utilities/powerbi.py index 88aee2ae087..7c3c1a1eaa6 100644 --- a/libs/community/langchain_community/utilities/powerbi.py +++ b/libs/community/langchain_community/utilities/powerbi.py @@ -15,7 +15,6 @@ from pydantic import ( ConfigDict, Field, model_validator, - validator, ) from requests.exceptions import Timeout @@ -50,15 +49,12 @@ class PowerBIDataset(BaseModel): arbitrary_types_allowed=True, ) - @validator("table_names", allow_reuse=True) - def fix_table_names(cls, table_names: List[str]) -> List[str]: - """Fix the table names.""" - return [fix_table_name(table) for table in table_names] - @model_validator(mode="before") @classmethod - def token_or_credential_present(cls, values: Dict[str, Any]) -> Any: + def validate_params(cls, values: Dict[str, Any]) -> Any: """Validate that at least one of token and credentials is present.""" + table_names = values.get("table_names", []) + values["table_names"] = [fix_table_name(table) for table in table_names] if "token" in values or "credential" in values: return values raise ValueError("Please provide either a credential or a token.") diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 75422be1302..6d4ef5d828c 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -66,7 +66,9 @@ markers = [ ] asyncio_mode = "auto" filterwarnings = [ + "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", "ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test", + "ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test", ] [tool.poetry.group.test]