community[patch]: fix community warnings 1 (#26239)

This commit is contained in:
Bagatur
2024-09-09 17:27:00 -07:00
committed by GitHub
parent 438301db90
commit f2f9187919
7 changed files with 53 additions and 43 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
import requests
from langchain_core.documents import Document
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, model_validator, validator
from pydantic import BaseModel, field_validator, model_validator
from langchain_community.document_loaders.base import BaseLoader
@@ -73,7 +73,8 @@ class GitHubIssuesLoader(BaseGitHubLoader):
"""Number of items per page.
Defaults to 30 in the GitHub API."""
@validator("since", allow_reuse=True)
@field_validator("since")
@classmethod
def validate_since(cls, v: Optional[str]) -> Optional[str]:
if v:
try:

View File

@@ -13,7 +13,6 @@ from pydantic import (
Field,
PrivateAttr,
model_validator,
validator,
)
__all__ = ["Databricks"]
@@ -414,18 +413,21 @@ class Databricks(LLM):
params["max_tokens"] = self.max_tokens
return params
@validator("cluster_id", always=True)
def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
if v and values["endpoint_name"]:
@model_validator(mode="before")
@classmethod
def set_cluster_id(cls, values: Dict[str, Any]) -> dict:
cluster_id = values.get("cluster_id")
endpoint_name = values.get("endpoint_name")
if cluster_id and endpoint_name:
raise ValueError("Cannot set both endpoint_name and cluster_id.")
elif values["endpoint_name"]:
return None
elif v:
return v
elif endpoint_name:
values["cluster_id"] = None
elif cluster_id:
pass
else:
try:
if v := get_repl_context().clusterId:
return v
if context_cluster_id := get_repl_context().clusterId:
values["cluster_id"] = context_cluster_id
raise ValueError("Context doesn't contain clusterId.")
except Exception as e:
raise ValueError(
@@ -434,27 +436,28 @@ class Databricks(LLM):
f" error: {e}"
)
@validator("cluster_driver_port", always=True)
def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
if v and values["endpoint_name"]:
cluster_driver_port = values.get("cluster_driver_port")
if cluster_driver_port and endpoint_name:
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
elif values["endpoint_name"]:
return None
elif v is None:
elif endpoint_name:
values["cluster_driver_port"] = None
elif cluster_driver_port is None:
raise ValueError(
"Must set cluster_driver_port to connect to a cluster driver."
)
elif int(v) <= 0:
raise ValueError(f"Invalid cluster_driver_port: {v}")
elif int(cluster_driver_port) <= 0:
raise ValueError(f"Invalid cluster_driver_port: {cluster_driver_port}")
else:
return v
pass
@validator("model_kwargs", always=True)
def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if v:
assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
assert "stop" not in v, "model_kwargs must not contain key 'stop'"
return v
if model_kwargs := values.get("model_kwargs"):
assert (
"prompt" not in model_kwargs
), "model_kwargs must not contain key 'prompt'"
assert (
"stop" not in model_kwargs
), "model_kwargs must not contain key 'stop'"
return values
def __init__(self, **data: Any):
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):

View File

@@ -385,6 +385,10 @@ class AmazonKendraRetriever(BaseRetriever):
@model_validator(mode="before")
@classmethod
def create_client(cls, values: Dict[str, Any]) -> Any:
top_k = values.get("top_k")
if top_k is not None and top_k < 0:
raise ValueError(f"top_k ({top_k}) cannot be negative.")
if values.get("client") is not None:
return values

View File

@@ -7,7 +7,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, model_validator
from langchain_community.tools.playwright.base import BaseBrowserTool
from langchain_community.tools.playwright.utils import (
@@ -21,13 +21,15 @@ class NavigateToolInput(BaseModel):
url: str = Field(..., description="url to navigate to")
@validator("url")
def validate_url_scheme(cls, url: str) -> str:
@model_validator(mode="before")
@classmethod
def validate_url_scheme(cls, values: dict) -> dict:
"""Check that the URL scheme is valid."""
url = values.get("url")
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError("URL scheme must be 'http' or 'https'")
return url
return values
class NavigateTool(BaseBrowserTool):

View File

@@ -9,7 +9,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import ConfigDict, Field, validator
from pydantic import ConfigDict, Field, model_validator
from langchain_community.chat_models.openai import _import_tiktoken
from langchain_community.tools.powerbi.prompt import (
@@ -43,18 +43,20 @@ class QueryPowerBITool(BaseTool):
arbitrary_types_allowed=True,
)
@validator("llm_chain")
@model_validator(mode="before")
@classmethod
def validate_llm_chain_input_variables( # pylint: disable=E0213
cls, llm_chain: Any
) -> Any:
cls, values: dict
) -> dict:
"""Make sure the LLM chain has the correct input variables."""
llm_chain = values["llm_chain"]
for var in llm_chain.prompt.input_variables:
if var not in ["tool_input", "tables", "schemas", "examples"]:
raise ValueError(
"LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301
llm_chain.prompt.input_variables,
)
return llm_chain
return values
def _check_cache(self, tool_input: str) -> Optional[str]:
"""Check if the input is present in the cache.

View File

@@ -15,7 +15,6 @@ from pydantic import (
ConfigDict,
Field,
model_validator,
validator,
)
from requests.exceptions import Timeout
@@ -50,15 +49,12 @@ class PowerBIDataset(BaseModel):
arbitrary_types_allowed=True,
)
@validator("table_names", allow_reuse=True)
def fix_table_names(cls, table_names: List[str]) -> List[str]:
"""Fix the table names."""
return [fix_table_name(table) for table in table_names]
@model_validator(mode="before")
@classmethod
def token_or_credential_present(cls, values: Dict[str, Any]) -> Any:
def validate_params(cls, values: Dict[str, Any]) -> Any:
"""Validate that at least one of token and credentials is present."""
table_names = values.get("table_names", [])
values["table_names"] = [fix_table_name(table) for table in table_names]
if "token" in values or "credential" in values:
return values
raise ValueError("Please provide either a credential or a token.")

View File

@@ -66,7 +66,9 @@ markers = [
]
asyncio_mode = "auto"
filterwarnings = [
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
"ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test",
"ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",
]
[tool.poetry.group.test]