mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-28 15:00:23 +00:00
community[patch]: fix community warnings 1 (#26239)
This commit is contained in:
@@ -13,7 +13,6 @@ from pydantic import (
|
||||
Field,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
validator,
|
||||
)
|
||||
|
||||
__all__ = ["Databricks"]
|
||||
@@ -414,18 +413,21 @@ class Databricks(LLM):
|
||||
params["max_tokens"] = self.max_tokens
|
||||
return params
|
||||
|
||||
@validator("cluster_id", always=True)
|
||||
def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||
if v and values["endpoint_name"]:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_cluster_id(cls, values: Dict[str, Any]) -> dict:
|
||||
cluster_id = values.get("cluster_id")
|
||||
endpoint_name = values.get("endpoint_name")
|
||||
if cluster_id and endpoint_name:
|
||||
raise ValueError("Cannot set both endpoint_name and cluster_id.")
|
||||
elif values["endpoint_name"]:
|
||||
return None
|
||||
elif v:
|
||||
return v
|
||||
elif endpoint_name:
|
||||
values["cluster_id"] = None
|
||||
elif cluster_id:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
if v := get_repl_context().clusterId:
|
||||
return v
|
||||
if context_cluster_id := get_repl_context().clusterId:
|
||||
values["cluster_id"] = context_cluster_id
|
||||
raise ValueError("Context doesn't contain clusterId.")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
@@ -434,27 +436,28 @@ class Databricks(LLM):
|
||||
f" error: {e}"
|
||||
)
|
||||
|
||||
@validator("cluster_driver_port", always=True)
|
||||
def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
|
||||
if v and values["endpoint_name"]:
|
||||
cluster_driver_port = values.get("cluster_driver_port")
|
||||
if cluster_driver_port and endpoint_name:
|
||||
raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
|
||||
elif values["endpoint_name"]:
|
||||
return None
|
||||
elif v is None:
|
||||
elif endpoint_name:
|
||||
values["cluster_driver_port"] = None
|
||||
elif cluster_driver_port is None:
|
||||
raise ValueError(
|
||||
"Must set cluster_driver_port to connect to a cluster driver."
|
||||
)
|
||||
elif int(v) <= 0:
|
||||
raise ValueError(f"Invalid cluster_driver_port: {v}")
|
||||
elif int(cluster_driver_port) <= 0:
|
||||
raise ValueError(f"Invalid cluster_driver_port: {cluster_driver_port}")
|
||||
else:
|
||||
return v
|
||||
pass
|
||||
|
||||
@validator("model_kwargs", always=True)
|
||||
def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if v:
|
||||
assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
|
||||
assert "stop" not in v, "model_kwargs must not contain key 'stop'"
|
||||
return v
|
||||
if model_kwargs := values.get("model_kwargs"):
|
||||
assert (
|
||||
"prompt" not in model_kwargs
|
||||
), "model_kwargs must not contain key 'prompt'"
|
||||
assert (
|
||||
"stop" not in model_kwargs
|
||||
), "model_kwargs must not contain key 'stop'"
|
||||
return values
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
|
||||
|
Reference in New Issue
Block a user