mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 14:26:48 +00:00
community[patch]: fix community warnings 1 (#26239)
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union
|
|||||||
import requests
|
import requests
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
from pydantic import BaseModel, model_validator, validator
|
from pydantic import BaseModel, field_validator, model_validator
|
||||||
|
|
||||||
from langchain_community.document_loaders.base import BaseLoader
|
from langchain_community.document_loaders.base import BaseLoader
|
||||||
|
|
||||||
@@ -73,7 +73,8 @@ class GitHubIssuesLoader(BaseGitHubLoader):
|
|||||||
"""Number of items per page.
|
"""Number of items per page.
|
||||||
Defaults to 30 in the GitHub API."""
|
Defaults to 30 in the GitHub API."""
|
||||||
|
|
||||||
@validator("since", allow_reuse=True)
|
@field_validator("since")
|
||||||
|
@classmethod
|
||||||
def validate_since(cls, v: Optional[str]) -> Optional[str]:
|
def validate_since(cls, v: Optional[str]) -> Optional[str]:
|
||||||
if v:
|
if v:
|
||||||
try:
|
try:
|
||||||
|
@@ -13,7 +13,6 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
PrivateAttr,
|
PrivateAttr,
|
||||||
model_validator,
|
model_validator,
|
||||||
validator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = ["Databricks"]
|
__all__ = ["Databricks"]
|
||||||
@@ -414,18 +413,21 @@ class Databricks(LLM):
|
|||||||
params["max_tokens"] = self.max_tokens
|
params["max_tokens"] = self.max_tokens
|
||||||
return params
|
return params
|
||||||
|
|
||||||
@validator("cluster_id", always=True)
|
@model_validator(mode="before")
|
||||||
def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
@classmethod
|
||||||
if v and values["endpoint_name"]:
|
def set_cluster_id(cls, values: Dict[str, Any]) -> dict:
|
||||||
|
cluster_id = values.get("cluster_id")
|
||||||
|
endpoint_name = values.get("endpoint_name")
|
||||||
|
if cluster_id and endpoint_name:
|
||||||
raise ValueError("Cannot set both endpoint_name and cluster_id.")
|
raise ValueError("Cannot set both endpoint_name and cluster_id.")
|
||||||
elif values["endpoint_name"]:
|
elif endpoint_name:
|
||||||
return None
|
values["cluster_id"] = None
|
||||||
elif v:
|
elif cluster_id:
|
||||||
return v
|
pass
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if v := get_repl_context().clusterId:
|
if context_cluster_id := get_repl_context().clusterId:
|
||||||
return v
|
values["cluster_id"] = context_cluster_id
|
||||||
raise ValueError("Context doesn't contain clusterId.")
|
raise ValueError("Context doesn't contain clusterId.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -434,27 +436,28 @@ class Databricks(LLM):
|
|||||||
f" error: {e}"
|
f" error: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("cluster_driver_port", always=True)
|
cluster_driver_port = values.get("cluster_driver_port")
|
||||||
def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
if cluster_driver_port and endpoint_name:
|
||||||
if v and values["endpoint_name"]:
|
|
||||||
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
|
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
|
||||||
elif values["endpoint_name"]:
|
elif endpoint_name:
|
||||||
return None
|
values["cluster_driver_port"] = None
|
||||||
elif v is None:
|
elif cluster_driver_port is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Must set cluster_driver_port to connect to a cluster driver."
|
"Must set cluster_driver_port to connect to a cluster driver."
|
||||||
)
|
)
|
||||||
elif int(v) <= 0:
|
elif int(cluster_driver_port) <= 0:
|
||||||
raise ValueError(f"Invalid cluster_driver_port: {v}")
|
raise ValueError(f"Invalid cluster_driver_port: {cluster_driver_port}")
|
||||||
else:
|
else:
|
||||||
return v
|
pass
|
||||||
|
|
||||||
@validator("model_kwargs", always=True)
|
if model_kwargs := values.get("model_kwargs"):
|
||||||
def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
assert (
|
||||||
if v:
|
"prompt" not in model_kwargs
|
||||||
assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
|
), "model_kwargs must not contain key 'prompt'"
|
||||||
assert "stop" not in v, "model_kwargs must not contain key 'stop'"
|
assert (
|
||||||
return v
|
"stop" not in model_kwargs
|
||||||
|
), "model_kwargs must not contain key 'stop'"
|
||||||
|
return values
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
def __init__(self, **data: Any):
|
||||||
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
||||||
|
@@ -385,6 +385,10 @@ class AmazonKendraRetriever(BaseRetriever):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_client(cls, values: Dict[str, Any]) -> Any:
|
def create_client(cls, values: Dict[str, Any]) -> Any:
|
||||||
|
top_k = values.get("top_k")
|
||||||
|
if top_k is not None and top_k < 0:
|
||||||
|
raise ValueError(f"top_k ({top_k}) cannot be negative.")
|
||||||
|
|
||||||
if values.get("client") is not None:
|
if values.get("client") is not None:
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@@ -7,7 +7,7 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from langchain_community.tools.playwright.base import BaseBrowserTool
|
from langchain_community.tools.playwright.base import BaseBrowserTool
|
||||||
from langchain_community.tools.playwright.utils import (
|
from langchain_community.tools.playwright.utils import (
|
||||||
@@ -21,13 +21,15 @@ class NavigateToolInput(BaseModel):
|
|||||||
|
|
||||||
url: str = Field(..., description="url to navigate to")
|
url: str = Field(..., description="url to navigate to")
|
||||||
|
|
||||||
@validator("url")
|
@model_validator(mode="before")
|
||||||
def validate_url_scheme(cls, url: str) -> str:
|
@classmethod
|
||||||
|
def validate_url_scheme(cls, values: dict) -> dict:
|
||||||
"""Check that the URL scheme is valid."""
|
"""Check that the URL scheme is valid."""
|
||||||
|
url = values.get("url")
|
||||||
parsed_url = urlparse(url)
|
parsed_url = urlparse(url)
|
||||||
if parsed_url.scheme not in ("http", "https"):
|
if parsed_url.scheme not in ("http", "https"):
|
||||||
raise ValueError("URL scheme must be 'http' or 'https'")
|
raise ValueError("URL scheme must be 'http' or 'https'")
|
||||||
return url
|
return values
|
||||||
|
|
||||||
|
|
||||||
class NavigateTool(BaseBrowserTool):
|
class NavigateTool(BaseBrowserTool):
|
||||||
|
@@ -9,7 +9,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from pydantic import ConfigDict, Field, validator
|
from pydantic import ConfigDict, Field, model_validator
|
||||||
|
|
||||||
from langchain_community.chat_models.openai import _import_tiktoken
|
from langchain_community.chat_models.openai import _import_tiktoken
|
||||||
from langchain_community.tools.powerbi.prompt import (
|
from langchain_community.tools.powerbi.prompt import (
|
||||||
@@ -43,18 +43,20 @@ class QueryPowerBITool(BaseTool):
|
|||||||
arbitrary_types_allowed=True,
|
arbitrary_types_allowed=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("llm_chain")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_llm_chain_input_variables( # pylint: disable=E0213
|
def validate_llm_chain_input_variables( # pylint: disable=E0213
|
||||||
cls, llm_chain: Any
|
cls, values: dict
|
||||||
) -> Any:
|
) -> dict:
|
||||||
"""Make sure the LLM chain has the correct input variables."""
|
"""Make sure the LLM chain has the correct input variables."""
|
||||||
|
llm_chain = values["llm_chain"]
|
||||||
for var in llm_chain.prompt.input_variables:
|
for var in llm_chain.prompt.input_variables:
|
||||||
if var not in ["tool_input", "tables", "schemas", "examples"]:
|
if var not in ["tool_input", "tables", "schemas", "examples"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301
|
"LLM chain for QueryPowerBITool must have input variables ['tool_input', 'tables', 'schemas', 'examples'], found %s", # noqa: E501 # pylint: disable=C0301
|
||||||
llm_chain.prompt.input_variables,
|
llm_chain.prompt.input_variables,
|
||||||
)
|
)
|
||||||
return llm_chain
|
return values
|
||||||
|
|
||||||
def _check_cache(self, tool_input: str) -> Optional[str]:
|
def _check_cache(self, tool_input: str) -> Optional[str]:
|
||||||
"""Check if the input is present in the cache.
|
"""Check if the input is present in the cache.
|
||||||
|
@@ -15,7 +15,6 @@ from pydantic import (
|
|||||||
ConfigDict,
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
model_validator,
|
model_validator,
|
||||||
validator,
|
|
||||||
)
|
)
|
||||||
from requests.exceptions import Timeout
|
from requests.exceptions import Timeout
|
||||||
|
|
||||||
@@ -50,15 +49,12 @@ class PowerBIDataset(BaseModel):
|
|||||||
arbitrary_types_allowed=True,
|
arbitrary_types_allowed=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@validator("table_names", allow_reuse=True)
|
|
||||||
def fix_table_names(cls, table_names: List[str]) -> List[str]:
|
|
||||||
"""Fix the table names."""
|
|
||||||
return [fix_table_name(table) for table in table_names]
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def token_or_credential_present(cls, values: Dict[str, Any]) -> Any:
|
def validate_params(cls, values: Dict[str, Any]) -> Any:
|
||||||
"""Validate that at least one of token and credentials is present."""
|
"""Validate that at least one of token and credentials is present."""
|
||||||
|
table_names = values.get("table_names", [])
|
||||||
|
values["table_names"] = [fix_table_name(table) for table in table_names]
|
||||||
if "token" in values or "credential" in values:
|
if "token" in values or "credential" in values:
|
||||||
return values
|
return values
|
||||||
raise ValueError("Please provide either a credential or a token.")
|
raise ValueError("Please provide either a credential or a token.")
|
||||||
|
@@ -66,7 +66,9 @@ markers = [
|
|||||||
]
|
]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
filterwarnings = [
|
filterwarnings = [
|
||||||
|
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
|
||||||
"ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test",
|
"ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test",
|
||||||
|
"ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.poetry.group.test]
|
[tool.poetry.group.test]
|
||||||
|
Reference in New Issue
Block a user