validate template (#865)

This commit is contained in:
Harrison Chase 2023-02-02 22:08:01 -08:00 committed by GitHub
parent 364b771743
commit e9ef08862d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 8 deletions

View File

@ -41,6 +41,9 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
template_format: str = "f-string" template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
"""Whether or not to try validating the template."""
@root_validator(pre=True) @root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict: def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided.""" """Check that one and only one of examples/example_selector are provided."""
@ -61,11 +64,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
@root_validator() @root_validator()
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent.""" """Check that prefix, suffix and input variables are consistent."""
check_valid_template( if values["validate_template"]:
values["prefix"] + values["suffix"], check_valid_template(
values["template_format"], values["prefix"] + values["suffix"],
values["input_variables"], values["template_format"],
) values["input_variables"],
)
return values return values
class Config: class Config:

View File

@ -31,6 +31,9 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
template_format: str = "f-string" template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
"""Whether or not to try validating the template."""
@property @property
def _prompt_type(self) -> str: def _prompt_type(self) -> str:
"""Return the prompt type key.""" """Return the prompt type key."""
@ -61,9 +64,10 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
@root_validator() @root_validator()
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent.""" """Check that template and input variables are consistent."""
check_valid_template( if values["validate_template"]:
values["template"], values["template_format"], values["input_variables"] check_valid_template(
) values["template"], values["template_format"], values["input_variables"]
)
return values return values
@classmethod @classmethod