mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 07:50:47 +00:00
Validate input_variables
when using jinja2
templates (#3140)
`langchain.prompts.PromptTemplate` and `langchain.prompts.FewShotPromptTemplate` do not validate `input_variables` when initialized as `jinja2` template. ```python # Using langchain v0.0.144 template = """"\ Your variable: {{ foo }} {% if bar %} You just set bar boolean variable to true {% endif %} """ # Missing variable, should raise ValueError prompt_template = PromptTemplate(template=template, input_variables=["bar"], template_format="jinja2", validate_template=True) # Extra variable, should raise ValueError prompt_template = PromptTemplate(template=template, input_variables=["bar", "foo", "extra", "thing"], template_format="jinja2", validate_template=True) ```
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
"""Test few shot prompt template."""
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
@@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
|
||||
example_template = "{{ word }}: {{ antonym }}"
|
||||
|
||||
examples = [
|
||||
{"word": "happy", "antonym": "sad"},
|
||||
{"word": "tall", "antonym": "short"},
|
||||
]
|
||||
|
||||
return (
|
||||
PromptTemplate(
|
||||
input_variables=["word", "antonym"],
|
||||
template=example_template,
|
||||
template_format="jinja2",
|
||||
),
|
||||
examples,
|
||||
)
|
||||
|
||||
|
||||
def test_suffix_only() -> None:
|
||||
"""Test prompt works with just a suffix."""
|
||||
suffix = "This is a {foo} test."
|
||||
@@ -174,3 +195,71 @@ def test_partial() -> None:
|
||||
"Now you try to talk about party."
|
||||
)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_prompt_jinja2_functionality(
|
||||
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
|
||||
) -> None:
|
||||
prefix = "Starting with {{ foo }}"
|
||||
suffix = "Ending with {{ bar }}"
|
||||
|
||||
prompt = FewShotPromptTemplate(
|
||||
input_variables=["foo", "bar"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
)
|
||||
output = prompt.format(foo="hello", bar="bye")
|
||||
expected_output = (
|
||||
"Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye"
|
||||
)
|
||||
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_prompt_jinja2_missing_input_variables(
|
||||
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
|
||||
) -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
prefix = "Starting with {{ foo }}"
|
||||
suffix = "Ending with {{ bar }}"
|
||||
|
||||
# Test when missing in suffix
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix=suffix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
)
|
||||
|
||||
# Test when missing in prefix
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=["bar"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
)
|
||||
|
||||
|
||||
def test_prompt_jinja2_extra_input_variables(
|
||||
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
|
||||
) -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
prefix = "Starting with {{ foo }}"
|
||||
suffix = "Ending with {{ bar }}"
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=["bar", "foo", "extra", "thing"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
)
|
||||
|
@@ -212,3 +212,33 @@ Your variable again: {{ foo }}
|
||||
template_format="jinja2",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_prompt_jinja2_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
template = "This is a {{ foo }} test."
|
||||
input_variables: list = []
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
|
||||
|
||||
def test_prompt_jinja2_extra_input_variables() -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
template = "This is a {{ foo }} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
|
||||
|
||||
def test_prompt_jinja2_wrong_input_variables() -> None:
|
||||
"""Test error is raised when name of input variable is wrong."""
|
||||
template = "This is a {{ foo }} test."
|
||||
input_variables = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
|
Reference in New Issue
Block a user