mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
core[patch]: Update remaining root_validators (#22829)
This PR updates the remaining root_validators in core to either be explicit pre-init or post-init validators.
This commit is contained in:
parent
265e650e64
commit
5dbbdcbf8e
@ -120,7 +120,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||||
"""[DEPRECATED] Callback manager to add to the run trace."""
|
"""[DEPRECATED] Callback manager to add to the run trace."""
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=True)
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
"""Raise deprecation warning if callback_manager is used."""
|
"""Raise deprecation warning if callback_manager is used."""
|
||||||
if values.get("callback_manager") is not None:
|
if values.get("callback_manager") is not None:
|
||||||
|
@ -232,7 +232,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=True)
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
"""Raise deprecation warning if callback_manager is used."""
|
"""Raise deprecation warning if callback_manager is used."""
|
||||||
if values.get("callback_manager") is not None:
|
if values.get("callback_manager") is not None:
|
||||||
|
@ -149,7 +149,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
|||||||
"invalid_tool_calls": self.invalid_tool_calls,
|
"invalid_tool_calls": self.invalid_tool_calls,
|
||||||
}
|
}
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def init_tool_calls(cls, values: dict) -> dict:
|
def init_tool_calls(cls, values: dict) -> dict:
|
||||||
if not values["tool_call_chunks"]:
|
if not values["tool_call_chunks"]:
|
||||||
values["tool_calls"] = []
|
values["tool_calls"] = []
|
||||||
|
@ -121,7 +121,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
|||||||
template_format: Literal["f-string", "jinja2"] = "f-string"
|
template_format: Literal["f-string", "jinja2"] = "f-string"
|
||||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def template_is_valid(cls, values: Dict) -> Dict:
|
def template_is_valid(cls, values: Dict) -> Dict:
|
||||||
"""Check that prefix, suffix, and input variables are consistent."""
|
"""Check that prefix, suffix, and input variables are consistent."""
|
||||||
if values["validate_template"]:
|
if values["validate_template"]:
|
||||||
|
@ -64,7 +64,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def template_is_valid(cls, values: Dict) -> Dict:
|
def template_is_valid(cls, values: Dict) -> Dict:
|
||||||
"""Check that prefix, suffix, and input variables are consistent."""
|
"""Check that prefix, suffix, and input variables are consistent."""
|
||||||
if values["validate_template"]:
|
if values["validate_template"]:
|
||||||
|
@ -47,7 +47,7 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
prompt.format(foo="bar")
|
prompt.format(foo="bar")
|
||||||
|
|
||||||
# Instantiation using initializer
|
# Instantiation using initializer
|
||||||
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
|
prompt = PromptTemplate(template="Say {foo}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -74,6 +74,43 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
validate_template: bool = False
|
validate_template: bool = False
|
||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def pre_init_validation(cls, values: Dict) -> Dict:
|
||||||
|
"""Check that template and input variables are consistent."""
|
||||||
|
if values.get("template") is None:
|
||||||
|
# Will let pydantic fail with a ValidationError if template
|
||||||
|
# is not provided.
|
||||||
|
return values
|
||||||
|
|
||||||
|
# Set some default values based on the field defaults
|
||||||
|
values.setdefault("template_format", "f-string")
|
||||||
|
values.setdefault("partial_variables", {})
|
||||||
|
|
||||||
|
if values.get("validate_template"):
|
||||||
|
if values["template_format"] == "mustache":
|
||||||
|
raise ValueError("Mustache templates cannot be validated.")
|
||||||
|
|
||||||
|
if "input_variables" not in values:
|
||||||
|
raise ValueError(
|
||||||
|
"Input variables must be provided to validate the template."
|
||||||
|
)
|
||||||
|
|
||||||
|
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
||||||
|
check_valid_template(
|
||||||
|
values["template"], values["template_format"], all_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
if values["template_format"]:
|
||||||
|
values["input_variables"] = [
|
||||||
|
var
|
||||||
|
for var in get_template_variables(
|
||||||
|
values["template"], values["template_format"]
|
||||||
|
)
|
||||||
|
if var not in values["partial_variables"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||||
if self.template_format != "mustache":
|
if self.template_format != "mustache":
|
||||||
return super().get_input_schema(config)
|
return super().get_input_schema(config)
|
||||||
@ -126,26 +163,6 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||||
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def template_is_valid(cls, values: Dict) -> Dict:
|
|
||||||
"""Check that template and input variables are consistent."""
|
|
||||||
if values["validate_template"]:
|
|
||||||
if values["template_format"] == "mustache":
|
|
||||||
raise ValueError("Mustache templates cannot be validated.")
|
|
||||||
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
|
||||||
check_valid_template(
|
|
||||||
values["template"], values["template_format"], all_inputs
|
|
||||||
)
|
|
||||||
elif values.get("template_format"):
|
|
||||||
values["input_variables"] = [
|
|
||||||
var
|
|
||||||
for var in get_template_variables(
|
|
||||||
values["template"], values["template_format"]
|
|
||||||
)
|
|
||||||
if var not in values["partial_variables"]
|
|
||||||
]
|
|
||||||
return values
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_examples(
|
def from_examples(
|
||||||
cls,
|
cls,
|
||||||
|
Loading…
Reference in New Issue
Block a user