mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 03:01:29 +00:00
core[patch]: update some root_validators (#22787)
Update some of the @root_validators to be explicit pre=True or pre=False, skip_on_failure=True for pydantic 2 compatibility.
This commit is contained in:
parent
3d6e8547f9
commit
74e705250f
@ -68,7 +68,7 @@ class AIMessage(BaseMessage):
|
|||||||
"invalid_tool_calls": self.invalid_tool_calls,
|
"invalid_tool_calls": self.invalid_tool_calls,
|
||||||
}
|
}
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=True)
|
||||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
||||||
tool_calls = (
|
tool_calls = (
|
||||||
|
@ -59,6 +59,29 @@ class BasePromptTemplate(
|
|||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
"""Tags to be used for tracing."""
|
"""Tags to be used for tracing."""
|
||||||
|
|
||||||
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
|
def validate_variable_names(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate variable names do not include restricted names."""
|
||||||
|
if "stop" in values["input_variables"]:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have an input variable named 'stop', as it is used internally,"
|
||||||
|
" please rename."
|
||||||
|
)
|
||||||
|
if "stop" in values["partial_variables"]:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have an partial variable named 'stop', as it is used "
|
||||||
|
"internally, please rename."
|
||||||
|
)
|
||||||
|
|
||||||
|
overall = set(values["input_variables"]).intersection(
|
||||||
|
values["partial_variables"]
|
||||||
|
)
|
||||||
|
if overall:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found overlapping input and partial variables: {overall}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
@ -155,29 +178,6 @@ class BasePromptTemplate(
|
|||||||
"""Create Prompt Value."""
|
"""Create Prompt Value."""
|
||||||
return self.format_prompt(**kwargs)
|
return self.format_prompt(**kwargs)
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def validate_variable_names(cls, values: Dict) -> Dict:
|
|
||||||
"""Validate variable names do not include restricted names."""
|
|
||||||
if "stop" in values["input_variables"]:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have an input variable named 'stop', as it is used internally,"
|
|
||||||
" please rename."
|
|
||||||
)
|
|
||||||
if "stop" in values["partial_variables"]:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot have an partial variable named 'stop', as it is used "
|
|
||||||
"internally, please rename."
|
|
||||||
)
|
|
||||||
|
|
||||||
overall = set(values["input_variables"]).intersection(
|
|
||||||
values["partial_variables"]
|
|
||||||
)
|
|
||||||
if overall:
|
|
||||||
raise ValueError(
|
|
||||||
f"Found overlapping input and partial variables: {overall}"
|
|
||||||
)
|
|
||||||
return values
|
|
||||||
|
|
||||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||||
"""Return a partial of the prompt template."""
|
"""Return a partial of the prompt template."""
|
||||||
prompt_dict = self.__dict__.copy()
|
prompt_dict = self.__dict__.copy()
|
||||||
|
@ -309,7 +309,7 @@ class ChildTool(BaseTool):
|
|||||||
}
|
}
|
||||||
return tool_input
|
return tool_input
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=True)
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
"""Raise deprecation warning if callback_manager is used."""
|
"""Raise deprecation warning if callback_manager is used."""
|
||||||
if values.get("callback_manager") is not None:
|
if values.get("callback_manager") is not None:
|
||||||
|
@ -772,17 +772,17 @@ class VectorStoreRetriever(BaseRetriever):
|
|||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=True)
|
||||||
def validate_search_type(cls, values: Dict) -> Dict:
|
def validate_search_type(cls, values: Dict) -> Dict:
|
||||||
"""Validate search type."""
|
"""Validate search type."""
|
||||||
search_type = values["search_type"]
|
search_type = values.get("search_type", "similarity")
|
||||||
if search_type not in cls.allowed_search_types:
|
if search_type not in cls.allowed_search_types:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"search_type of {search_type} not allowed. Valid values are: "
|
f"search_type of {search_type} not allowed. Valid values are: "
|
||||||
f"{cls.allowed_search_types}"
|
f"{cls.allowed_search_types}"
|
||||||
)
|
)
|
||||||
if search_type == "similarity_score_threshold":
|
if search_type == "similarity_score_threshold":
|
||||||
score_threshold = values["search_kwargs"].get("score_threshold")
|
score_threshold = values.get("search_kwargs", {}).get("score_threshold")
|
||||||
if (score_threshold is None) or (not isinstance(score_threshold, float)):
|
if (score_threshold is None) or (not isinstance(score_threshold, float)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`score_threshold` is not specified with a float value(0~1) "
|
"`score_threshold` is not specified with a float value(0~1) "
|
||||||
|
@ -23,7 +23,7 @@ class MyRunnable(RunnableSerializable[str, str]):
|
|||||||
raise ValueError("Cannot set _my_hidden_property")
|
raise ValueError("Cannot set _my_hidden_property")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
values["_my_hidden_property"] = values["my_property"]
|
values["_my_hidden_property"] = values["my_property"]
|
||||||
return values
|
return values
|
||||||
|
Loading…
Reference in New Issue
Block a user