mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 13:59:49 +00:00
community[patch]: fix community warnings 1 (#26239)
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
|
||||
import requests
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, model_validator, validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
@@ -73,7 +73,8 @@ class GitHubIssuesLoader(BaseGitHubLoader):
|
||||
"""Number of items per page.
|
||||
Defaults to 30 in the GitHub API."""
|
||||
|
||||
@validator("since", allow_reuse=True)
|
||||
@field_validator("since")
|
||||
@classmethod
|
||||
def validate_since(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v:
|
||||
try:
|
||||
|
@@ -13,7 +13,6 @@ from pydantic import (
|
||||
Field,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
validator,
|
||||
)
|
||||
|
||||
__all__ = ["Databricks"]
|
||||
@@ -414,18 +413,21 @@ class Databricks(LLM):
|
||||
params["max_tokens"] = self.max_tokens
|
||||
return params
|
||||
|
||||
@validator("cluster_id", always=True)
|
||||
def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||
if v and values["endpoint_name"]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_cluster_id(cls, values: Dict[str, Any]) -> dict:
|
||||
cluster_id = values.get("cluster_id")
|
||||
endpoint_name = values.get("endpoint_name")
|
||||
if cluster_id and endpoint_name:
|
||||
raise ValueError("Cannot set both endpoint_name and cluster_id.")
|
||||
elif values["endpoint_name"]:
|
||||
return None
|
||||
elif v:
|
||||
return v
|
||||
elif endpoint_name:
|
||||
values["cluster_id"] = None
|
||||
elif cluster_id:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
if v := get_repl_context().clusterId:
|
||||
return v
|
||||
if context_cluster_id := get_repl_context().clusterId:
|
||||
values["cluster_id"] = context_cluster_id
|
||||
raise ValueError("Context doesn't contain clusterId.")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
@@ -434,27 +436,28 @@ class Databricks(LLM):
|
||||
f" error: {e}"
|
||||
)
|
||||
|
||||
@validator("cluster_driver_port", always=True)
|
||||
def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||
if v and values["endpoint_name"]:
|
||||
cluster_driver_port = values.get("cluster_driver_port")
|
||||
if cluster_driver_port and endpoint_name:
|
||||
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
|
||||
elif values["endpoint_name"]:
|
||||
return None
|
||||
elif v is None:
|
||||
elif endpoint_name:
|
||||
values["cluster_driver_port"] = None
|
||||
elif cluster_driver_port is None:
|
||||
raise ValueError(
|
||||
"Must set cluster_driver_port to connect to a cluster driver."
|
||||
)
|
||||
elif int(v) <= 0:
|
||||
raise ValueError(f"Invalid cluster_driver_port: {v}")
|
||||
elif int(cluster_driver_port) <= 0:
|
||||
raise ValueError(f"Invalid cluster_driver_port: {cluster_driver_port}")
|
||||
else:
|
||||
return v
|
||||
pass
|
||||
|
||||
@validator("model_kwargs", always=True)
|
||||
def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if v:
|
||||
assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
|
||||
assert "stop" not in v, "model_kwargs must not contain key 'stop'"
|
||||
return v
|
||||
if model_kwargs := values.get("model_kwargs"):
|
||||
assert (
|
||||
"prompt" not in model_kwargs
|
||||
), "model_kwargs must not contain key 'prompt'"
|
||||
assert (
|
||||
"stop" not in model_kwargs
|
||||
), "model_kwargs must not contain key 'stop'"
|
||||
return values
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
||||
|
@@ -385,6 +385,10 @@ class AmazonKendraRetriever(BaseRetriever):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_client(cls, values: Dict[str, Any]) -> Any:
|
||||
top_k = values.get("top_k")
|
||||
if top_k is not None and top_k < 0:
|
||||
raise ValueError(f"top_k ({top_k}) cannot be negative.")
|
||||
|
||||
if values.get("client") is not None:
|
||||
return values
|
||||
|
||||
|
@@ -7,7 +7,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from langchain_community.tools.playwright.base import BaseBrowserTool
|
||||
from langchain_community.tools.playwright.utils import (
|
||||
@@ -21,13 +21,15 @@ class NavigateToolInput(BaseModel):
|
||||
|
||||
url: str = Field(..., description="url to navigate to")
|
||||
|
||||
@validator("url")
|
||||
def validate_url_scheme(cls, url: str) -> str:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_url_scheme(cls, values: dict) -> dict:
|
||||
"""Check that the URL scheme is valid."""
|
||||
url = values.get("url")
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError("URL scheme must be 'http' or 'https'")
|
||||
return url
|
||||
return values
|
||||
|
||||
|
||||
class NavigateTool(BaseBrowserTool):
|
||||
|
@@ -9,7 +9,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import ConfigDict, Field, validator
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from langchain_community.chat_models.openai import _import_tiktoken
|
||||
from langchain_community.tools.powerbi.prompt import (
|
||||
@@ -43,18 +43,20 @@ class QueryPowerBITool(BaseTool):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@validator("llm_chain")
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_llm_chain_input_variables( # pylint: disable=E0213
|
||||
cls, llm_chain: Any
|
||||
) -> Any:
|
||||
cls, values: dict
|
||||
) -> dict:
|
||||
"""Make sure the LLM chain has the correct input variables."""
|
||||
llm_chain = values["llm_chain"]
|
||||
for var in llm_chain.prompt.input_variables:
|
||||
if var not in ["tool_input", "tables", "schemas", "examples"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301
|
||||
llm_chain.prompt.input_variables,
|
||||
)
|
||||
return llm_chain
|
||||
return values
|
||||
|
||||
def _check_cache(self, tool_input: str) -> Optional[str]:
|
||||
"""Check if the input is present in the cache.
|
||||
|
@@ -15,7 +15,6 @@ from pydantic import (
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
validator,
|
||||
)
|
||||
from requests.exceptions import Timeout
|
||||
|
||||
@@ -50,15 +49,12 @@ class PowerBIDataset(BaseModel):
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
@validator("table_names", allow_reuse=True)
|
||||
def fix_table_names(cls, table_names: List[str]) -> List[str]:
|
||||
"""Fix the table names."""
|
||||
return [fix_table_name(table) for table in table_names]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def token_or_credential_present(cls, values: Dict[str, Any]) -> Any:
|
||||
def validate_params(cls, values: Dict[str, Any]) -> Any:
|
||||
"""Validate that at least one of token and credentials is present."""
|
||||
table_names = values.get("table_names", [])
|
||||
values["table_names"] = [fix_table_name(table) for table in table_names]
|
||||
if "token" in values or "credential" in values:
|
||||
return values
|
||||
raise ValueError("Please provide either a credential or a token.")
|
||||
|
@@ -66,7 +66,9 @@ markers = [
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
filterwarnings = [
|
||||
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
|
||||
"ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test",
|
||||
"ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",
|
||||
]
|
||||
|
||||
[tool.poetry.group.test]
|
||||
|
Reference in New Issue
Block a user