mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
core[patch]: Update remaining root_validators (#22829)
This PR updates the remaining root_validators in core to either be explicit pre-init or post-init validators.
This commit is contained in:
parent
265e650e64
commit
5dbbdcbf8e
@ -120,7 +120,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""[DEPRECATED] Callback manager to add to the run trace."""
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
|
@ -232,7 +232,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
|
@ -149,7 +149,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def init_tool_calls(cls, values: dict) -> dict:
|
||||
if not values["tool_call_chunks"]:
|
||||
values["tool_calls"] = []
|
||||
|
@ -121,7 +121,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
template_format: Literal["f-string", "jinja2"] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
|
@ -64,7 +64,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
|
@ -47,7 +47,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
prompt.format(foo="bar")
|
||||
|
||||
# Instantiation using initializer
|
||||
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
|
||||
prompt = PromptTemplate(template="Say {foo}")
|
||||
"""
|
||||
|
||||
@property
|
||||
@ -74,6 +74,43 @@ class PromptTemplate(StringPromptTemplate):
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_init_validation(cls, values: Dict) -> Dict:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values.get("template") is None:
|
||||
# Will let pydantic fail with a ValidationError if template
|
||||
# is not provided.
|
||||
return values
|
||||
|
||||
# Set some default values based on the field defaults
|
||||
values.setdefault("template_format", "f-string")
|
||||
values.setdefault("partial_variables", {})
|
||||
|
||||
if values.get("validate_template"):
|
||||
if values["template_format"] == "mustache":
|
||||
raise ValueError("Mustache templates cannot be validated.")
|
||||
|
||||
if "input_variables" not in values:
|
||||
raise ValueError(
|
||||
"Input variables must be provided to validate the template."
|
||||
)
|
||||
|
||||
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
||||
check_valid_template(
|
||||
values["template"], values["template_format"], all_inputs
|
||||
)
|
||||
|
||||
if values["template_format"]:
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["template"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
|
||||
return values
|
||||
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
if self.template_format != "mustache":
|
||||
return super().get_input_schema(config)
|
||||
@ -126,26 +163,6 @@ class PromptTemplate(StringPromptTemplate):
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
if values["template_format"] == "mustache":
|
||||
raise ValueError("Mustache templates cannot be validated.")
|
||||
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
||||
check_valid_template(
|
||||
values["template"], values["template_format"], all_inputs
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["template"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_examples(
|
||||
cls,
|
||||
|
Loading…
Reference in New Issue
Block a user