mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(core): handle parent/child mustache vars (#33345)
**Description:**
currently `mustache_schema("{{x.y}} {{x}}")` will error. pr fixes
**Issue:** na
**Dependencies:**na
---------
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Sequence
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@@ -149,9 +149,7 @@ def mustache_template_vars(
|
|||||||
Defs = dict[str, "Defs"]
|
Defs = dict[str, "Defs"]
|
||||||
|
|
||||||
|
|
||||||
def mustache_schema(
|
def mustache_schema(template: str) -> type[BaseModel]:
|
||||||
template: str,
|
|
||||||
) -> type[BaseModel]:
|
|
||||||
"""Get the variables from a mustache template.
|
"""Get the variables from a mustache template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -175,6 +173,11 @@ def mustache_schema(
|
|||||||
fields[prefix] = False
|
fields[prefix] = False
|
||||||
elif type_ in {"variable", "no escape"}:
|
elif type_ in {"variable", "no escape"}:
|
||||||
fields[prefix + tuple(key.split("."))] = True
|
fields[prefix + tuple(key.split("."))] = True
|
||||||
|
|
||||||
|
for fkey, fval in fields.items():
|
||||||
|
fields[fkey] = fval and not any(
|
||||||
|
is_subsequence(fkey, k) for k in fields if k != fkey
|
||||||
|
)
|
||||||
defs: Defs = {} # None means leaf node
|
defs: Defs = {} # None means leaf node
|
||||||
while fields:
|
while fields:
|
||||||
field, is_leaf = fields.popitem()
|
field, is_leaf = fields.popitem()
|
||||||
@@ -327,3 +330,12 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
|||||||
def pretty_print(self) -> None:
|
def pretty_print(self) -> None:
|
||||||
"""Print a pretty representation of the prompt."""
|
"""Print a pretty representation of the prompt."""
|
||||||
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
|
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
|
||||||
|
|
||||||
|
|
||||||
|
def is_subsequence(child: Sequence, parent: Sequence) -> bool:
|
||||||
|
"""Return True if child is subsequence of parent."""
|
||||||
|
if len(child) == 0 or len(parent) == 0:
|
||||||
|
return False
|
||||||
|
if len(parent) < len(child):
|
||||||
|
return False
|
||||||
|
return all(child[i] == parent[i] for i in range(len(child)))
|
||||||
|
|||||||
32
libs/core/tests/unit_tests/prompts/test_string.py
Normal file
32
libs/core/tests/unit_tests/prompts/test_string.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from langchain_core.prompts.string import mustache_schema
|
||||||
|
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not PYDANTIC_VERSION_AT_LEAST_29,
|
||||||
|
reason=(
|
||||||
|
"Only test with most recent version of pydantic. "
|
||||||
|
"Pydantic introduced small fixes to generated JSONSchema on minor versions."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_mustache_schema_parent_child() -> None:
|
||||||
|
template = "{{x.y}} {{x}}"
|
||||||
|
expected = {
|
||||||
|
"$defs": {
|
||||||
|
"x": {
|
||||||
|
"properties": {"y": {"default": None, "title": "Y", "type": "string"}},
|
||||||
|
"title": "x",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"properties": {"x": {"$ref": "#/$defs/x", "default": None}},
|
||||||
|
"title": "PromptInput",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
actual = mustache_schema(template).model_json_schema()
|
||||||
|
assert expected == actual
|
||||||
Reference in New Issue
Block a user