diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index 130b657d48d..fe0d4bedb95 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -68,7 +68,7 @@ class AIMessage(BaseMessage): "invalid_tool_calls": self.invalid_tool_calls, } - @root_validator() + @root_validator(pre=True) def _backwards_compat_tool_calls(cls, values: dict) -> dict: raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls") tool_calls = ( diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 381ce854d3e..147b61e6a19 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -59,6 +59,29 @@ class BasePromptTemplate( tags: Optional[List[str]] = None """Tags to be used for tracing.""" + @root_validator(pre=False, skip_on_failure=True) + def validate_variable_names(cls, values: Dict) -> Dict: + """Validate variable names do not include restricted names.""" + if "stop" in values["input_variables"]: + raise ValueError( + "Cannot have an input variable named 'stop', as it is used internally," + " please rename." + ) + if "stop" in values["partial_variables"]: + raise ValueError( + "Cannot have an partial variable named 'stop', as it is used " + "internally, please rename." + ) + + overall = set(values["input_variables"]).intersection( + values["partial_variables"] + ) + if overall: + raise ValueError( + f"Found overlapping input and partial variables: {overall}" + ) + return values + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" @@ -155,29 +178,6 @@ class BasePromptTemplate( """Create Prompt Value.""" return self.format_prompt(**kwargs) - @root_validator() - def validate_variable_names(cls, values: Dict) -> Dict: - """Validate variable names do not include restricted names.""" - if "stop" in values["input_variables"]: - raise ValueError( - "Cannot have an input variable named 'stop', as it is used internally," - " please rename." - ) - if "stop" in values["partial_variables"]: - raise ValueError( - "Cannot have an partial variable named 'stop', as it is used " - "internally, please rename." - ) - - overall = set(values["input_variables"]).intersection( - values["partial_variables"] - ) - if overall: - raise ValueError( - f"Found overlapping input and partial variables: {overall}" - ) - return values - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: """Return a partial of the prompt template.""" prompt_dict = self.__dict__.copy() diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 26348fc7b52..86cedd51aa9 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -309,7 +309,7 @@ class ChildTool(BaseTool): } return tool_input - @root_validator() + @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: diff --git a/libs/core/langchain_core/vectorstores.py b/libs/core/langchain_core/vectorstores.py index 04b281262f5..9098538590d 100644 --- a/libs/core/langchain_core/vectorstores.py +++ b/libs/core/langchain_core/vectorstores.py @@ -772,17 +772,17 @@ class VectorStoreRetriever(BaseRetriever): arbitrary_types_allowed = True - @root_validator() + @root_validator(pre=True) def validate_search_type(cls, values: Dict) -> Dict: """Validate search type.""" - search_type = values["search_type"] + search_type = values.get("search_type", "similarity") if search_type not in cls.allowed_search_types: raise ValueError( f"search_type of {search_type} not allowed. Valid values are: " f"{cls.allowed_search_types}" ) if search_type == "similarity_score_threshold": - score_threshold = values["search_kwargs"].get("score_threshold") + score_threshold = values.get("search_kwargs", {}).get("score_threshold") if (score_threshold is None) or (not isinstance(score_threshold, float)): raise ValueError( "`score_threshold` is not specified with a float value(0~1) " diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index c5d74df5ee9..70712494613 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -23,7 +23,7 @@ class MyRunnable(RunnableSerializable[str, str]): raise ValueError("Cannot set _my_hidden_property") return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["_my_hidden_property"] = values["my_property"] return values