From e9ef08862df79f01a5723c8e155c544f670dff36 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 2 Feb 2023 22:08:01 -0800 Subject: [PATCH] validate template (#865) --- langchain/prompts/few_shot.py | 14 +++++++++----- langchain/prompts/prompt.py | 10 +++++++--- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 71b4df37988..5a7b0c54fbd 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -41,6 +41,9 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel): template_format: str = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + validate_template: bool = True + """Whether or not to try validating the template.""" + @root_validator(pre=True) def check_examples_and_selector(cls, values: Dict) -> Dict: """Check that one and only one of examples/example_selector are provided.""" @@ -61,11 +64,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel): @root_validator() def template_is_valid(cls, values: Dict) -> Dict: """Check that prefix, suffix and input variables are consistent.""" - check_valid_template( - values["prefix"] + values["suffix"], - values["template_format"], - values["input_variables"], - ) + if values["validate_template"]: + check_valid_template( + values["prefix"] + values["suffix"], + values["template_format"], + values["input_variables"], + ) return values class Config: diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index fe9a798c5ed..eed015f51cb 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -31,6 +31,9 @@ class PromptTemplate(BasePromptTemplate, BaseModel): template_format: str = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + validate_template: bool = True + """Whether or not to try validating the template.""" + @property def _prompt_type(self) -> str: """Return the prompt type key.""" @@ -61,9 +64,10 @@ class PromptTemplate(BasePromptTemplate, BaseModel): @root_validator() def template_is_valid(cls, values: Dict) -> Dict: """Check that template and input variables are consistent.""" - check_valid_template( - values["template"], values["template_format"], values["input_variables"] - ) + if values["validate_template"]: + check_valid_template( + values["template"], values["template_format"], values["input_variables"] + ) return values @classmethod