mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 19:11:33 +00:00
core[patch]: update some root_validators (#22787)
Update some of the @root_validators to be explicit pre=True or pre=False, skip_on_failure=True for pydantic 2 compatibility.
This commit is contained in:
parent
3d6e8547f9
commit
74e705250f
@ -68,7 +68,7 @@ class AIMessage(BaseMessage):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
||||
tool_calls = (
|
||||
|
@ -59,6 +59,29 @@ class BasePromptTemplate(
|
||||
tags: Optional[List[str]] = None
|
||||
"""Tags to be used for tracing."""
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
@ -155,29 +178,6 @@ class BasePromptTemplate(
|
||||
"""Create Prompt Value."""
|
||||
return self.format_prompt(**kwargs)
|
||||
|
||||
@root_validator()
|
||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Validate variable names do not include restricted names."""
|
||||
if "stop" in values["input_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||
" please rename."
|
||||
)
|
||||
if "stop" in values["partial_variables"]:
|
||||
raise ValueError(
|
||||
"Cannot have an partial variable named 'stop', as it is used "
|
||||
"internally, please rename."
|
||||
)
|
||||
|
||||
overall = set(values["input_variables"]).intersection(
|
||||
values["partial_variables"]
|
||||
)
|
||||
if overall:
|
||||
raise ValueError(
|
||||
f"Found overlapping input and partial variables: {overall}"
|
||||
)
|
||||
return values
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||
"""Return a partial of the prompt template."""
|
||||
prompt_dict = self.__dict__.copy()
|
||||
|
@ -309,7 +309,7 @@ class ChildTool(BaseTool):
|
||||
}
|
||||
return tool_input
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
|
@ -772,17 +772,17 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_search_type(cls, values: Dict) -> Dict:
|
||||
"""Validate search type."""
|
||||
search_type = values["search_type"]
|
||||
search_type = values.get("search_type", "similarity")
|
||||
if search_type not in cls.allowed_search_types:
|
||||
raise ValueError(
|
||||
f"search_type of {search_type} not allowed. Valid values are: "
|
||||
f"{cls.allowed_search_types}"
|
||||
)
|
||||
if search_type == "similarity_score_threshold":
|
||||
score_threshold = values["search_kwargs"].get("score_threshold")
|
||||
score_threshold = values.get("search_kwargs", {}).get("score_threshold")
|
||||
if (score_threshold is None) or (not isinstance(score_threshold, float)):
|
||||
raise ValueError(
|
||||
"`score_threshold` is not specified with a float value(0~1) "
|
||||
|
@ -23,7 +23,7 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
raise ValueError("Cannot set _my_hidden_property")
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values["_my_hidden_property"] = values["my_property"]
|
||||
return values
|
||||
|
Loading…
Reference in New Issue
Block a user