community[patch]: fix community warnings 1 (#26239)

This commit is contained in:
Bagatur
2024-09-09 17:27:00 -07:00
committed by GitHub
parent 438301db90
commit f2f9187919
7 changed files with 53 additions and 43 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
 import requests
 from langchain_core.documents import Document
 from langchain_core.utils import get_from_dict_or_env
-from pydantic import BaseModel, model_validator, validator
+from pydantic import BaseModel, field_validator, model_validator
 from langchain_community.document_loaders.base import BaseLoader
@@ -73,7 +73,8 @@ class GitHubIssuesLoader(BaseGitHubLoader):
 """Number of items per page.
 Defaults to 30 in the GitHub API."""
-@validator("since", allow_reuse=True)
+@field_validator("since")
+@classmethod
 def validate_since(cls, v: Optional[str]) -> Optional[str]:
 if v:
 try:

View File

@@ -13,7 +13,6 @@ from pydantic import (
 Field,
 PrivateAttr,
 model_validator,
-validator,
 )
 __all__ = ["Databricks"]
@@ -414,18 +413,21 @@ class Databricks(LLM):
 params["max_tokens"] = self.max_tokens
 return params
-@validator("cluster_id", always=True)
-def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
-if v and values["endpoint_name"]:
+@model_validator(mode="before")
+@classmethod
+def set_cluster_id(cls, values: Dict[str, Any]) -> dict:
+cluster_id = values.get("cluster_id")
+endpoint_name = values.get("endpoint_name")
+if cluster_id and endpoint_name:
 raise ValueError("Cannot set both endpoint_name and cluster_id.")
-elif values["endpoint_name"]:
-return None
-elif v:
-return v
+elif endpoint_name:
+values["cluster_id"] = None
+elif cluster_id:
+pass
 else:
 try:
-if v := get_repl_context().clusterId:
-return v
+if context_cluster_id := get_repl_context().clusterId:
+values["cluster_id"] = context_cluster_id
 raise ValueError("Context doesn't contain clusterId.")
 except Exception as e:
 raise ValueError(
@@ -434,27 +436,28 @@ class Databricks(LLM):
 f" error: {e}"
 )
-@validator("cluster_driver_port", always=True)
-def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
-if v and values["endpoint_name"]:
+cluster_driver_port = values.get("cluster_driver_port")
+if cluster_driver_port and endpoint_name:
 raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
-elif values["endpoint_name"]:
-return None
-elif v is None:
+elif endpoint_name:
+values["cluster_driver_port"] = None
+elif cluster_driver_port is None:
 raise ValueError(
 "Must set cluster_driver_port to connect to a cluster driver."
 )
-elif int(v) <= 0:
-raise ValueError(f"Invalid cluster_driver_port: {v}")
+elif int(cluster_driver_port) <= 0:
+raise ValueError(f"Invalid cluster_driver_port: {cluster_driver_port}")
 else:
-return v
+pass
-@validator("model_kwargs", always=True)
-def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
-if v:
-assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
-assert "stop" not in v, "model_kwargs must not contain key 'stop'"
-return v
+if model_kwargs := values.get("model_kwargs"):
+assert (
+"prompt" not in model_kwargs
+), "model_kwargs must not contain key 'prompt'"
+assert (
+"stop" not in model_kwargs
+), "model_kwargs must not contain key 'stop'"
+return values
 def __init__(self, **data: Any):
 if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):

View File

@@ -385,6 +385,10 @@ class AmazonKendraRetriever(BaseRetriever):
 @model_validator(mode="before")
 @classmethod
 def create_client(cls, values: Dict[str, Any]) -> Any:
+top_k = values.get("top_k")
+if top_k is not None and top_k < 0:
+raise ValueError(f"top_k ({top_k}) cannot be negative.")
 if values.get("client") is not None:
 return values

View File

@@ -7,7 +7,7 @@ from langchain_core.callbacks import (
 AsyncCallbackManagerForToolRun,
 CallbackManagerForToolRun,
 )
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, model_validator
 from langchain_community.tools.playwright.base import BaseBrowserTool
 from langchain_community.tools.playwright.utils import (
@@ -21,13 +21,15 @@ class NavigateToolInput(BaseModel):
 url: str = Field(..., description="url to navigate to")
-@validator("url")
-def validate_url_scheme(cls, url: str) -> str:
+@model_validator(mode="before")
+@classmethod
+def validate_url_scheme(cls, values: dict) -> dict:
 """Check that the URL scheme is valid."""
+url = values.get("url")
 parsed_url = urlparse(url)
 if parsed_url.scheme not in ("http", "https"):
 raise ValueError("URL scheme must be 'http' or 'https'")
-return url
+return values
 class NavigateTool(BaseBrowserTool):

View File

@@ -9,7 +9,7 @@ from langchain_core.callbacks import (
 CallbackManagerForToolRun,
 )
 from langchain_core.tools import BaseTool
-from pydantic import ConfigDict, Field, validator
+from pydantic import ConfigDict, Field, model_validator
 from langchain_community.chat_models.openai import _import_tiktoken
 from langchain_community.tools.powerbi.prompt import (
@@ -43,18 +43,20 @@ class QueryPowerBITool(BaseTool):
 arbitrary_types_allowed=True,
 )
-@validator("llm_chain")
+@model_validator(mode="before")
+@classmethod
 def validate_llm_chain_input_variables( # pylint: disable=E0213
-cls, llm_chain: Any
-) -> Any:
+cls, values: dict
+) -> dict:
 """Make sure the LLM chain has the correct input variables."""
+llm_chain = values["llm_chain"]
 for var in llm_chain.prompt.input_variables:
 if var not in ["tool_input", "tables", "schemas", "examples"]:
 raise ValueError(
 "LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301
 llm_chain.prompt.input_variables,
 )
-return llm_chain
+return values
 def _check_cache(self, tool_input: str) -> Optional[str]:
 """Check if the input is present in the cache.

View File

@@ -15,7 +15,6 @@ from pydantic import (
 ConfigDict,
 Field,
 model_validator,
-validator,
 )
 from requests.exceptions import Timeout
@@ -50,15 +49,12 @@ class PowerBIDataset(BaseModel):
 arbitrary_types_allowed=True,
 )
-@validator("table_names", allow_reuse=True)
-def fix_table_names(cls, table_names: List[str]) -> List[str]:
-"""Fix the table names."""
-return [fix_table_name(table) for table in table_names]
 @model_validator(mode="before")
 @classmethod
-def token_or_credential_present(cls, values: Dict[str, Any]) -> Any:
+def validate_params(cls, values: Dict[str, Any]) -> Any:
 """Validate that at least one of token and credentials is present."""
+table_names = values.get("table_names", [])
+values["table_names"] = [fix_table_name(table) for table in table_names]
 if "token" in values or "credential" in values:
 return values
 raise ValueError("Please provide either a credential or a token.")

View File

@@ -66,7 +66,9 @@ markers = [
 ]
 asyncio_mode = "auto"
 filterwarnings = [
+"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
 "ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test",
+"ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",
 ]
 [tool.poetry.group.test]