mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
Support inference of input_variables
from jinja2
template (#3013)
`langchain.prompts.PromptTemplate` is unable to infer `input_variables` from jinja2 template. ```python # Using langchain v0.0.141 template_string = """\ Hello world Your variable: {{ var }} {# This will not get rendered #} {% if verbose %} Congrats! You just turned on verbose mode and got extra messages! {% endif %} """ template = PromptTemplate.from_template(template_string, template_format="jinja2") print(template.input_variables) # Output ['# This will not get rendered #', '% endif %', '% if verbose %'] ``` --------- Co-authored-by: engkheng <ongengkheng929@example.com>
This commit is contained in:
parent
dac32c59e5
commit
19febc77d6
@ -3,8 +3,9 @@ from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from string import Formatter
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
|
||||
from jinja2 import Environment, meta
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.prompts.base import (
|
||||
@ -14,6 +15,13 @@ from langchain.prompts.base import (
|
||||
)
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
return variables
|
||||
|
||||
|
||||
class PromptTemplate(StringPromptTemplate):
|
||||
"""Schema to represent a prompt for an LLM.
|
||||
|
||||
@ -125,9 +133,15 @@ class PromptTemplate(StringPromptTemplate):
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
|
||||
"""Load a prompt template from a template."""
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
|
||||
# Get the variables for the template
|
||||
input_variables = _get_jinja2_variables_from_template(template)
|
||||
|
||||
else:
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
|
||||
return cls(
|
||||
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
||||
)
|
||||
|
@ -145,3 +145,70 @@ def test_partial() -> None:
|
||||
assert new_result == "This is a 3 test."
|
||||
result = prompt.format(foo="foo")
|
||||
assert result == "This is a foo test."
|
||||
|
||||
|
||||
def test_prompt_from_jinja2_template() -> None:
|
||||
"""Test prompts can be constructed from a jinja2 template."""
|
||||
# Empty input variable.
|
||||
template = """Hello there
|
||||
There is no variable here {
|
||||
Will it get confused{ }?
|
||||
"""
|
||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||
expected_prompt = PromptTemplate(
|
||||
template=template, input_variables=[], template_format="jinja2"
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
# Multiple input variables.
|
||||
template = """\
|
||||
Hello world
|
||||
|
||||
Your variable: {{ foo }}
|
||||
|
||||
{# This will not get rendered #}
|
||||
|
||||
{% if bar %}
|
||||
You just set bar boolean variable to true
|
||||
{% endif %}
|
||||
|
||||
{% for i in foo_list %}
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
"""
|
||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||
expected_prompt = PromptTemplate(
|
||||
template=template,
|
||||
input_variables=["bar", "foo", "foo_list"],
|
||||
template_format="jinja2",
|
||||
)
|
||||
|
||||
assert prompt == expected_prompt
|
||||
|
||||
# Multiple input variables with repeats.
|
||||
template = """\
|
||||
Hello world
|
||||
|
||||
Your variable: {{ foo }}
|
||||
|
||||
{# This will not get rendered #}
|
||||
|
||||
{% if bar %}
|
||||
You just set bar boolean variable to true
|
||||
{% endif %}
|
||||
|
||||
{% for i in foo_list %}
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
|
||||
{% if bar %}
|
||||
Your variable again: {{ foo }}
|
||||
{% endif %}
|
||||
"""
|
||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||
expected_prompt = PromptTemplate(
|
||||
template=template,
|
||||
input_variables=["bar", "foo", "foo_list"],
|
||||
template_format="jinja2",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
Loading…
Reference in New Issue
Block a user