From 5dbbdcbf8e38886cf9281e297f19ab661a608b38 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 12 Jun 2024 14:47:40 -0400 Subject: [PATCH] core[patch]: Update remaining root_validators (#22829) This PR updates the remaining root_validators in core to either be explicit pre-init or post-init validators. --- .../language_models/chat_models.py | 2 +- .../langchain_core/language_models/llms.py | 2 +- libs/core/langchain_core/messages/ai.py | 2 +- libs/core/langchain_core/prompts/few_shot.py | 2 +- .../prompts/few_shot_with_templates.py | 2 +- libs/core/langchain_core/prompts/prompt.py | 59 ++++++++++++------- 6 files changed, 43 insertions(+), 26 deletions(-) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f968790e130..e68270ac6cb 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -120,7 +120,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """[DEPRECATED] Callback manager to add to the run trace.""" - @root_validator() + @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 6ec473c9c07..f2c7bf89703 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -232,7 +232,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): arbitrary_types_allowed = True - @root_validator() + @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used.""" if values.get("callback_manager") is not None: diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index fe0d4bedb95..bf7e2c0747d 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -149,7 +149,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk): "invalid_tool_calls": self.invalid_tool_calls, } - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def init_tool_calls(cls, values: dict) -> dict: if not values["tool_call_chunks"]: values["tool_calls"] = [] diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index 2fbc7d7a46e..800ee306544 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -121,7 +121,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): template_format: Literal["f-string", "jinja2"] = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def template_is_valid(cls, values: Dict) -> Dict: """Check that prefix, suffix, and input variables are consistent.""" if values["validate_template"]: diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index a6724a179ad..19c2c941dec 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -64,7 +64,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def template_is_valid(cls, values: Dict) -> Dict: """Check that prefix, suffix, and input variables are consistent.""" if values["validate_template"]: diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 1dec3cb0f23..baa5f65e7ad 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -47,7 +47,7 @@ class PromptTemplate(StringPromptTemplate): prompt.format(foo="bar") # Instantiation using initializer - prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") + prompt = PromptTemplate(template="Say {foo}") """ @property @@ -74,6 +74,43 @@ class PromptTemplate(StringPromptTemplate): validate_template: bool = False """Whether or not to try validating the template.""" + @root_validator(pre=True) + def pre_init_validation(cls, values: Dict) -> Dict: + """Check that template and input variables are consistent.""" + if values.get("template") is None: + # Will let pydantic fail with a ValidationError if template + # is not provided. + return values + + # Set some default values based on the field defaults + values.setdefault("template_format", "f-string") + values.setdefault("partial_variables", {}) + + if values.get("validate_template"): + if values["template_format"] == "mustache": + raise ValueError("Mustache templates cannot be validated.") + + if "input_variables" not in values: + raise ValueError( + "Input variables must be provided to validate the template." + ) + + all_inputs = values["input_variables"] + list(values["partial_variables"]) + check_valid_template( + values["template"], values["template_format"], all_inputs + ) + + if values["template_format"]: + values["input_variables"] = [ + var + for var in get_template_variables( + values["template"], values["template_format"] + ) + if var not in values["partial_variables"] + ] + + return values + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: if self.template_format != "mustache": return super().get_input_schema(config) @@ -126,26 +163,6 @@ class PromptTemplate(StringPromptTemplate): kwargs = self._merge_partial_and_user_variables(**kwargs) return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) - @root_validator() - def template_is_valid(cls, values: Dict) -> Dict: - """Check that template and input variables are consistent.""" - if values["validate_template"]: - if values["template_format"] == "mustache": - raise ValueError("Mustache templates cannot be validated.") - all_inputs = values["input_variables"] + list(values["partial_variables"]) - check_valid_template( - values["template"], values["template_format"], all_inputs - ) - elif values.get("template_format"): - values["input_variables"] = [ - var - for var in get_template_variables( - values["template"], values["template_format"] - ) - if var not in values["partial_variables"] - ] - return values - @classmethod def from_examples( cls,