fix(core): sanitize prompts more (#36613)

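Add more sanitization to prompts: reject nested replacement fields in f-string format specifiers, and validate `DictPromptTemplate` and `ImagePromptTemplate` templates when the model is constructed.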
This commit is contained in:
Eugene Yurtsev
2026-04-08 14:32:25 -04:00
committed by GitHub
parent 4fe869f1c3
commit 6bab0ba3c1
7 changed files with 228 additions and 31 deletions

View File

@@ -1,9 +1,12 @@
"""Dict prompt template."""
from __future__ import annotations
import warnings
from functools import cached_property
from typing import Any, Literal, Optional
from pydantic import model_validator
from typing_extensions import override
from langchain_core.load import dumpd
@@ -25,6 +28,12 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]):
template: dict[str, Any]
template_format: Literal["f-string", "mustache"]
@model_validator(mode="after")
def validate_template(self) -> DictPromptTemplate:
    """Run input-variable extraction over the template as a validity check.

    Executed by Pydantic after the fields have been populated. The extracted
    variables are discarded: the call is made only so that any exception it
    raises propagates out of model construction.

    Returns:
        This instance, unchanged.
    """
    # NOTE(review): `_get_input_variables` is defined outside this view;
    # presumably it raises ValueError for unsafe variable names - confirm.
    template, template_format = self.template, self.template_format
    _get_input_variables(template, template_format)
    return self
@property
def input_variables(self) -> list[str]:
"""Template input variables."""

View File

@@ -2,13 +2,15 @@
from typing import Any
from pydantic import Field
from pydantic import Field, model_validator
from typing_extensions import Self
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
get_template_variables,
)
from langchain_core.runnables import run_in_executor
@@ -40,8 +42,17 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
f" Found: {overlap}"
)
raise ValueError(msg)
super().__init__(**kwargs)
@model_validator(mode="after")
def validate_template(self) -> Self:
    """Run template-variable extraction over each string value of the template.

    Executed by Pydantic after the fields have been parsed. Non-string values
    are ignored. The extracted variable names are discarded; the call is made
    so that any exception raised by `get_template_variables` propagates out
    of model construction.

    Returns:
        This instance, unchanged.
    """
    string_values = (
        entry for entry in self.template.values() if isinstance(entry, str)
    )
    for text in string_values:
        get_template_variables(text, self.template_format)
    return self
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""

View File

@@ -263,6 +263,46 @@ DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
}
def _parse_f_string_fields(template: str) -> list[tuple[str, str | None]]:
fields: list[tuple[str, str | None]] = []
for _, field_name, format_spec, _ in Formatter().parse(template):
if field_name is not None:
fields.append((field_name, format_spec))
return fields
def validate_f_string_template(template: str) -> list[str]:
    """Check an f-string template for disallowed fields and list its variables.

    Replacement fields are inspected in order of appearance and the first
    violation found aborts validation.

    Args:
        template: The f-string style template to inspect.

    Returns:
        The distinct field names used by the template, sorted.

    Raises:
        ValueError: If a field name contains `.`, `[` or `]`, if a field name
            consists only of digits, or if a format specifier contains `{`
            or `}`.
    """
    # NOTE(review): an empty field name ("{}") passes all three checks below
    # and is returned as the variable "" - confirm that this is intended.
    names: set[str] = set()
    for name, spec in _parse_f_string_fields(template):
        uses_lookup_syntax = any(token in name for token in (".", "[", "]"))
        if uses_lookup_syntax:
            msg = (
                f"Invalid variable name {name!r} in f-string template. "
                f"Variable names cannot contain attribute "
                f"access (.) or indexing ([])."
            )
            raise ValueError(msg)

        if name.isdigit():
            msg = (
                f"Invalid variable name {name!r} in f-string template. "
                f"Variable names cannot be all digits as they are interpreted "
                f"as positional arguments."
            )
            raise ValueError(msg)

        # A brace inside the specifier means a nested replacement field.
        has_nested_field = bool(spec) and any(brace in spec for brace in "{}")
        if has_nested_field:
            msg = (
                "Invalid format specifier in f-string template. "
                "Nested replacement fields are not allowed."
            )
            raise ValueError(msg)

        names.add(name)
    return sorted(names)
def check_valid_template(
template: str, template_format: str, input_variables: list[str]
) -> None:
@@ -285,6 +325,8 @@ def check_valid_template(
f" {list(DEFAULT_FORMATTER_MAPPING)}."
)
raise ValueError(msg) from exc
if template_format == "f-string":
validate_f_string_template(template)
try:
validator_func(template, input_variables)
except (KeyError, IndexError) as exc:
@@ -308,43 +350,18 @@ def get_template_variables(template: str, template_format: str) -> list[str]:
Raises:
ValueError: If the template format is not supported.
"""
input_variables: list[str] | set[str]
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
input_variables = sorted(_get_jinja2_variables_from_template(template))
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
input_variables = validate_f_string_template(template)
elif template_format == "mustache":
input_variables = mustache_template_vars(template)
else:
msg = f"Unsupported template format: {template_format}"
raise ValueError(msg)
# For f-strings, block attribute access and indexing syntax
# This prevents template injection attacks via accessing dangerous attributes
if template_format == "f-string":
for var in input_variables:
# Formatter().parse() returns field names with dots/brackets if present
# e.g., "obj.attr" or "obj[0]" - we need to block these
if "." in var or "[" in var or "]" in var:
msg = (
f"Invalid variable name {var!r} in f-string template. "
f"Variable names cannot contain attribute "
f"access (.) or indexing ([])."
)
raise ValueError(msg)
# Block variable names that are all digits (e.g., "0", "100")
# These are interpreted as positional arguments, not keyword arguments
if var.isdigit():
msg = (
f"Invalid variable name {var!r} in f-string template. "
f"Variable names cannot be all digits as they are interpreted "
f"as positional arguments."
)
raise ValueError(msg)
return sorted(input_variables)

View File

@@ -1300,6 +1300,24 @@ def test_fstring_rejects_invalid_identifier_variable_names() -> None:
assert result.messages[0].content == expected # type: ignore[attr-defined]
def test_fstring_rejects_nested_replacement_field_in_image_url() -> None:
    """A nested replacement field in an image URL fails prompt creation."""
    image_block = {
        "type": "image_url",
        "image_url": {"url": "{img:{img.__class__.__name__}}"},
    }
    human_message = ("human", [image_block])
    with pytest.raises(ValueError, match="Nested replacement fields are not allowed"):
        ChatPromptTemplate.from_messages(
            [human_message],
            template_format="f-string",
        )
def test_mustache_template_attribute_access_vulnerability() -> None:
"""Test that Mustache template injection is blocked.

View File

@@ -1,4 +1,9 @@
from langchain_core.load import load
import json
import pytest
from langchain_core.load import load, loads
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.dict import DictPromptTemplate
@@ -32,3 +37,82 @@ def test_deserialize_legacy() -> None:
template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
)
assert load(ser, allowed_objects=[DictPromptTemplate]) == expected
def test_dict_prompt_template_rejects_attribute_access_to_rich_objects() -> None:
    """Constructing a dict prompt with attribute/index access raises."""
    unsafe_template = {"output": "{message.additional_kwargs[secret]}"}
    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        DictPromptTemplate(template=unsafe_template, template_format="f-string")
def test_dict_prompt_template_loads_payload_rejects_attribute_access() -> None:
    """`loads` raises on a JSON payload whose template uses attribute access."""
    serialized = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain_core", "prompts", "dict", "DictPromptTemplate"],
        "kwargs": {
            "template": {"output": "{message.additional_kwargs[secret]}"},
            "template_format": "f-string",
        },
    }
    payload = json.dumps(serialized)
    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        loads(payload)
def test_dict_prompt_template_dumpd_round_trip_rejects_attribute_access() -> None:
    """`load` raises on a dict payload whose template uses attribute access."""
    constructor_kwargs = {
        "template": {"output": "{message.additional_kwargs[secret]}"},
        "template_format": "f-string",
    }
    payload = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain_core", "prompts", "dict", "DictPromptTemplate"],
        "kwargs": constructor_kwargs,
    }
    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        load(payload, allowed_objects=[DictPromptTemplate])
def test_dict_prompt_template_deserialization_rejects_attribute_access() -> None:
    """`loads` raises on a JSON payload whose template reads dunder attributes."""
    serialized = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain_core", "prompts", "dict", "DictPromptTemplate"],
        "kwargs": {
            "template": {"output": "{name.__class__.__name__}"},
            "template_format": "f-string",
        },
    }
    payload = json.dumps(serialized)
    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        loads(payload)
def test_dict_prompt_template_legacy_deserialization_rejects_attribute_access() -> None:
    """`load` raises on a payload using the legacy id with attribute access."""
    legacy_id = ["langchain_core", "prompts", "message", "_DictMessagePromptTemplate"]
    ser = {
        "type": "constructor",
        "lc": 1,
        "id": legacy_id,
        "kwargs": {
            "template_format": "f-string",
            "template": {"output": "{name.__class__.__name__}"},
        },
    }
    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        load(ser, allowed_objects=[DictPromptTemplate])
def test_prompt_template_blocks_attribute_access() -> None:
    """`PromptTemplate.from_template` raises on an attribute-access field."""
    expected_error = "Variable names cannot contain attribute access"
    with pytest.raises(ValueError, match=expected_error):
        PromptTemplate.from_template("{name.__class__}", template_format="f-string")

View File

@@ -1,7 +1,11 @@
import json
import pytest
from pydantic import ValidationError
from langchain_core.load import dump, loads
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
def test_image_prompt_template_deserializable() -> None:
@@ -15,6 +19,11 @@ def test_image_prompt_template_deserializable() -> None:
)
def test_image_prompt_template_invalid_template_type() -> None:
    """Passing `None` as the template raises a Pydantic `ValidationError`."""
    not_a_template = None
    with pytest.raises(ValidationError):
        ImagePromptTemplate(template=not_a_template)
def test_image_prompt_template_deserializable_old() -> None:
"""Test that the image prompt template is serializable."""
loads(

View File

@@ -1,7 +1,12 @@
import pytest
from packaging import version
from langchain_core.prompts.string import mustache_schema
from langchain_core.prompts.string import (
check_valid_template,
get_template_variables,
mustache_schema,
)
from langchain_core.utils.formatting import formatter
from langchain_core.utils.pydantic import PYDANTIC_VERSION
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
@@ -30,3 +35,47 @@ def test_mustache_schema_parent_child() -> None:
}
actual = mustache_schema(template).model_json_schema()
assert expected == actual
def test_get_template_variables_rejects_nested_replacement_field_in_format_spec() -> (
    None
):
    """`get_template_variables` raises when a format spec nests a field."""
    nested_spec_template = "{name:{name.__class__.__name__}}"
    expected_error = "Nested replacement fields are not allowed"
    with pytest.raises(ValueError, match=expected_error):
        get_template_variables(nested_spec_template, "f-string")
def test_formatter_rejects_nested_replacement_field_in_format_spec() -> None:
    """`formatter.format` raises when a format spec nests a field."""
    nested_spec_template = "{name:{name.__class__.__name__}}"
    values = {"name": "hello"}
    with pytest.raises(ValueError, match="Invalid format specifier"):
        formatter.format(nested_spec_template, **values)
def test_check_valid_template_rejects_nested_replacement_field_in_format_spec() -> None:
    """`check_valid_template` raises when a format spec nests a field."""
    nested_spec_template = "{name:{name.__class__.__name__}}"
    declared_variables = ["name"]
    with pytest.raises(ValueError, match="Nested replacement fields are not allowed"):
        check_valid_template(nested_spec_template, "f-string", declared_variables)
@pytest.mark.parametrize(
    ("template", "kwargs", "expected_variables", "expected_output"),
    [
        ("{value:.2f}", {"value": 3.14159}, ["value"], "3.14"),
        # ">10" right-aligns within a field of width 10, so "cat" is preceded
        # by seven spaces of padding.
        ("{value:>10}", {"value": "cat"}, ["value"], "       cat"),
        ("{value:*^10}", {"value": "cat"}, ["value"], "***cat****"),
        ("{value:,}", {"value": 1234567}, ["value"], "1,234,567"),
        ("{value:%}", {"value": 0.125}, ["value"], "12.500000%"),
        ("{value!r}", {"value": "cat"}, ["value"], "'cat'"),
    ],
)
def test_f_string_templates_allow_safe_format_specs(
    template: str,
    kwargs: dict[str, object],
    expected_variables: list[str],
    expected_output: str,
) -> None:
    """Format specs and conversions without nested fields remain usable.

    Args:
        template: f-string template containing a single ``value`` field.
        kwargs: Values passed to ``formatter.format``.
        expected_variables: Variables ``get_template_variables`` must report.
        expected_output: Exact text ``formatter.format`` must produce.
    """
    assert get_template_variables(template, "f-string") == expected_variables
    assert formatter.format(template, **kwargs) == expected_output