mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
Include "no escape" and "inverted section" mustache vars in Prompt.input_variables and Prompt.input_schema (#22981)
This commit is contained in:
parent
7a0b36501f
commit
f01f12ce1e
@ -103,9 +103,12 @@ def mustache_template_vars(
|
||||
in_section = False
|
||||
elif in_section:
|
||||
continue
|
||||
elif type in ("variable", "section") and key != ".":
|
||||
elif (
|
||||
type in ("variable", "section", "inverted section", "no escape")
|
||||
and key != "."
|
||||
):
|
||||
vars.add(key.split(".")[0])
|
||||
if type == "section":
|
||||
if type in ("section", "inverted section"):
|
||||
in_section = True
|
||||
return vars
|
||||
|
||||
@ -117,24 +120,25 @@ def mustache_schema(
|
||||
template: str,
|
||||
) -> Type[BaseModel]:
|
||||
"""Get the variables from a mustache template."""
|
||||
fields = set()
|
||||
fields = {}
|
||||
prefix: Tuple[str, ...] = ()
|
||||
for type, key in mustache.tokenize(template):
|
||||
if key == ".":
|
||||
continue
|
||||
if type == "end":
|
||||
prefix = prefix[: -key.count(".")]
|
||||
elif type == "section":
|
||||
elif type in ("section", "inverted section"):
|
||||
prefix = prefix + tuple(key.split("."))
|
||||
elif type == "variable":
|
||||
fields.add(prefix + tuple(key.split(".")))
|
||||
fields[prefix] = False
|
||||
elif type in ("variable", "no escape"):
|
||||
fields[prefix + tuple(key.split("."))] = True
|
||||
defs: Defs = {} # None means leaf node
|
||||
while fields:
|
||||
field = fields.pop()
|
||||
field, is_leaf = fields.popitem()
|
||||
current = defs
|
||||
for part in field[:-1]:
|
||||
current = current.setdefault(part, {})
|
||||
current[field[-1]] = {}
|
||||
current.setdefault(field[-1], "" if is_leaf else {}) # type: ignore[arg-type]
|
||||
return _create_model_recursive("PromptInput", defs)
|
||||
|
||||
|
||||
@ -142,7 +146,7 @@ def _create_model_recursive(name: str, defs: Defs) -> Type:
|
||||
return create_model( # type: ignore[call-overload]
|
||||
name,
|
||||
**{
|
||||
k: (_create_model_recursive(k, v), None) if v else (str, None)
|
||||
k: (_create_model_recursive(k, v), None) if v else (type(v), None)
|
||||
for k, v in defs.items()
|
||||
},
|
||||
)
|
||||
|
@ -67,7 +67,7 @@ def test_mustache_prompt_from_template() -> None:
|
||||
}
|
||||
|
||||
# Multiple input variables with repeats.
|
||||
template = "This {{bar}} is a {{foo}} test {{foo}}."
|
||||
template = "This {{bar}} is a {{foo}} test {{&foo}}."
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
|
||||
assert prompt.input_variables == ["bar", "foo"]
|
||||
@ -81,7 +81,7 @@ def test_mustache_prompt_from_template() -> None:
|
||||
}
|
||||
|
||||
# Nested variables.
|
||||
template = "This {{obj.bar}} is a {{obj.foo}} test {{foo}}."
|
||||
template = "This {{obj.bar}} is a {{obj.foo}} test {{{foo}}}."
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format(obj={"bar": "foo", "foo": "bar"}, foo="baz") == (
|
||||
"This foo is a bar test baz."
|
||||
@ -167,6 +167,22 @@ def test_mustache_prompt_from_template() -> None:
|
||||
},
|
||||
}
|
||||
|
||||
template = """This{{^foo}}
|
||||
no foos
|
||||
{{/foo}}is a test."""
|
||||
prompt = PromptTemplate.from_template(template, template_format="mustache")
|
||||
assert prompt.format() == (
|
||||
"""This
|
||||
no foos
|
||||
is a test."""
|
||||
)
|
||||
assert prompt.input_variables == ["foo"]
|
||||
assert prompt.input_schema.schema() == {
|
||||
"title": "PromptInput",
|
||||
"type": "object",
|
||||
"properties": {"foo": {"title": "Foo", "type": "object"}},
|
||||
}
|
||||
|
||||
|
||||
def test_prompt_from_template_with_partial_variables() -> None:
|
||||
"""Test prompts can be constructed from a template with partial variables."""
|
||||
|
Loading…
Reference in New Issue
Block a user