Validate input_variables when using jinja2 templates (#3140)

`langchain.prompts.PromptTemplate` and
`langchain.prompts.FewShotPromptTemplate` do not validate
`input_variables` when initialized as `jinja2` template.

```python
# Using langchain v0.0.144
template = """"\
Your variable: {{ foo }}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
"""

# Missing variable, should raise ValueError
prompt_template = PromptTemplate(template=template, 
                                 input_variables=["bar"], 
                                 template_format="jinja2", 
                                 validate_template=True)

# Extra variable, should raise ValueError
prompt_template = PromptTemplate(template=template, 
                                 input_variables=["bar", "foo", "extra", "thing"], 
                                 template_format="jinja2", 
                                 validate_template=True)
```
This commit is contained in:
engkheng
2023-04-20 07:18:32 +08:00
committed by GitHub
parent 3e0c44bae8
commit dbbc340f25
5 changed files with 167 additions and 20 deletions

View File

@@ -1,4 +1,6 @@
"""Test few shot prompt template."""
from typing import Dict, List, Tuple
import pytest
from langchain.prompts.few_shot import FewShotPromptTemplate
@@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate(
)
@pytest.fixture()
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
example_template = "{{ word }}: {{ antonym }}"
examples = [
{"word": "happy", "antonym": "sad"},
{"word": "tall", "antonym": "short"},
]
return (
PromptTemplate(
input_variables=["word", "antonym"],
template=example_template,
template_format="jinja2",
),
examples,
)
def test_suffix_only() -> None:
"""Test prompt works with just a suffix."""
suffix = "This is a {foo} test."
@@ -174,3 +195,71 @@ def test_partial() -> None:
"Now you try to talk about party."
)
assert output == expected_output
def test_prompt_jinja2_functionality(
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
prompt = FewShotPromptTemplate(
input_variables=["foo", "bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
output = prompt.format(foo="hello", bar="bye")
expected_output = (
"Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye"
)
assert output == expected_output
def test_prompt_jinja2_missing_input_variables(
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
"""Test error is raised when input variables are not provided."""
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
# Test when missing in suffix
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
# Test when missing in prefix
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
def test_prompt_jinja2_extra_input_variables(
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
"""Test error is raised when there are too many input variables."""
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)

View File

@@ -212,3 +212,33 @@ Your variable again: {{ foo }}
template_format="jinja2",
)
assert prompt == expected_prompt
def test_prompt_jinja2_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {{ foo }} test."
input_variables: list = []
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
def test_prompt_jinja2_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {{ foo }} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
def test_prompt_jinja2_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
template = "This is a {{ foo }} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)