diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 0d8c8ce1171..5f5104a20ba 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -97,19 +97,18 @@ def mustache_template_vars( ) -> Set[str]: """Get the variables from a mustache template.""" vars: Set[str] = set() - in_section = False + section_depth = 0 for type, key in mustache.tokenize(template): if type == "end": - in_section = False - elif in_section: - continue + section_depth -= 1 elif ( type in ("variable", "section", "inverted section", "no escape") and key != "." + and section_depth == 0 ): vars.add(key.split(".")[0]) - if type in ("section", "inverted section"): - in_section = True + if type in ("section", "inverted section"): + section_depth += 1 return vars @@ -122,12 +121,15 @@ def mustache_schema( """Get the variables from a mustache template.""" fields = {} prefix: Tuple[str, ...] = () + section_stack: List[Tuple[str, ...]] = [] for type, key in mustache.tokenize(template): if key == ".": continue if type == "end": - prefix = prefix[: -key.count(".")] + if section_stack: + prefix = section_stack.pop() elif type in ("section", "inverted section"): + section_stack.append(prefix) prefix = prefix + tuple(key.split(".")) fields[prefix] = False elif type in ("variable", "no escape"): diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 4c8423b7ee6..ecb30b2c79c 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -141,6 +141,115 @@ def test_mustache_prompt_from_template() -> None: }, } + # more complex nested section/context variables + template = """This{{#foo}} + {{bar}} + {{#baz}} + {{qux}} + {{/baz}} + {{quux}} + {{/foo}}is a test.""" + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format( + foo={"bar": "yo", "baz": [{"qux": "wassup"}], "quux": "hello"} + ) == ( + """This + yo + wassup + hello + is a test.""" + ) + assert prompt.input_variables == ["foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"foo": {"$ref": "#/definitions/foo"}}, + "definitions": { + "foo": { + "title": "foo", + "type": "object", + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "baz": {"$ref": "#/definitions/baz"}, + "quux": {"title": "Quux", "type": "string"}, + }, + }, + "baz": { + "title": "baz", + "type": "object", + "properties": {"qux": {"title": "Qux", "type": "string"}}, + }, + }, + } + + # triply nested section/context variables + template = """This{{#foo}} + {{bar}} + {{#baz.qux}} + {{#barfoo}} + {{foobar}} + {{/barfoo}} + {{foobar}} + {{/baz.qux}} + {{quux}} + {{/foo}}is a test.""" + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format( + foo={ + "bar": "yo", + "baz": { + "qux": [ + {"foobar": "wassup"}, + {"foobar": "yoyo", "barfoo": {"foobar": "hello there"}}, + ] + }, + "quux": "hello", + } + ) == ( + """This + yo + wassup + hello there + yoyo + hello + is a test.""" + ) + assert prompt.input_variables == ["foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"foo": {"$ref": "#/definitions/foo"}}, + "definitions": { + "foo": { + "title": "foo", + "type": "object", + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "baz": {"$ref": "#/definitions/baz"}, + "quux": {"title": "Quux", "type": "string"}, + }, + }, + "baz": { + "title": "baz", + "type": "object", + "properties": {"qux": {"$ref": "#/definitions/qux"}}, + }, + "qux": { + "title": "qux", + "type": "object", + "properties": { + "foobar": {"title": "Foobar", "type": "string"}, + "barfoo": {"$ref": "#/definitions/barfoo"}, + }, + }, + "barfoo": { + "title": "barfoo", + "type": "object", + "properties": {"foobar": {"title": "Foobar", "type": "string"}}, + }, + }, + } + # section/context variables with repeats template = """This{{#foo}} {{bar}}