mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(core): add more sanitization to templates (#36612)
add more sanitization to templates
This commit is contained in:
@@ -4,6 +4,7 @@ import warnings
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
|
from pydantic import model_validator
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
@@ -21,11 +22,35 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]):
|
|||||||
Recognizes variables in f-string or mustache formatted string dict values.
|
Recognizes variables in f-string or mustache formatted string dict values.
|
||||||
|
|
||||||
Does NOT recognize variables in dict keys. Applies recursively.
|
Does NOT recognize variables in dict keys. Applies recursively.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
prompt = DictPromptTemplate(
|
||||||
|
template={
|
||||||
|
"type": "text",
|
||||||
|
"text": "Hello {name}",
|
||||||
|
"metadata": {"source": "{source}"},
|
||||||
|
},
|
||||||
|
template_format="f-string",
|
||||||
|
)
|
||||||
|
prompt.format(name="Alice", source="docs")
|
||||||
|
# {
|
||||||
|
# "type": "text",
|
||||||
|
# "text": "Hello Alice",
|
||||||
|
# "metadata": {"source": "docs"},
|
||||||
|
# }
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
template: dict[str, Any]
|
template: dict[str, Any]
|
||||||
template_format: Literal["f-string", "mustache"]
|
template_format: Literal["f-string", "mustache"]
|
||||||
|
|
||||||
|
@model_validator(mode="after")
def validate_template(self) -> "DictPromptTemplate":
    """Check the template's variables once the model has been built.

    Runs the module's variable extraction over ``template`` using
    ``template_format``. The extracted variables are discarded, so the call
    matters only for any error it raises.

    Returns:
        This instance, unchanged.
    """
    template, template_format = self.template, self.template_format
    # Only the validation side effect is wanted; the variable list is unused.
    _get_input_variables(template, template_format)
    return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self) -> list[str]:
|
def input_variables(self) -> list[str]:
|
||||||
"""Template input variables."""
|
"""Template input variables."""
|
||||||
|
|||||||
@@ -9,12 +9,25 @@ from langchain_core.prompts.base import BasePromptTemplate
|
|||||||
from langchain_core.prompts.string import (
|
from langchain_core.prompts.string import (
|
||||||
DEFAULT_FORMATTER_MAPPING,
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
PromptTemplateFormat,
|
PromptTemplateFormat,
|
||||||
|
get_template_variables,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables import run_in_executor
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||||
"""Image prompt template for a multimodal model."""
|
"""Image prompt template for a multimodal model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
prompt = ImagePromptTemplate(
|
||||||
|
input_variables=["image_id"],
|
||||||
|
template={"url": "https://example.com/{image_id}.png", "detail": "high"},
|
||||||
|
template_format="f-string",
|
||||||
|
)
|
||||||
|
prompt.format(image_id="cat")
|
||||||
|
# {"url": "https://example.com/cat.png", "detail": "high"}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
template: dict = Field(default_factory=dict)
|
template: dict = Field(default_factory=dict)
|
||||||
"""Template for the prompt."""
|
"""Template for the prompt."""
|
||||||
@@ -43,6 +56,13 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
|||||||
f" Found: {overlap}"
|
f" Found: {overlap}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
template = kwargs.get("template", {})
|
||||||
|
template_format = kwargs.get("template_format", "f-string")
|
||||||
|
for value in template.values():
|
||||||
|
if isinstance(value, str):
|
||||||
|
get_template_variables(value, template_format)
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -219,6 +219,46 @@ DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_f_string_fields(template: str) -> list[tuple[str, str | None]]:
|
||||||
|
fields: list[tuple[str, str | None]] = []
|
||||||
|
for _, field_name, format_spec, _ in Formatter().parse(template):
|
||||||
|
if field_name is not None:
|
||||||
|
fields.append((field_name, format_spec))
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
|
def validate_f_string_template(template: str) -> list[str]:
    """Validate an f-string template and return its input variables.

    Every replacement field is checked, in order, against three rules:

    * the field name must not contain ``.``, ``[`` or ``]``;
    * the field name must not consist solely of digits;
    * a non-empty format spec must not contain ``{`` or ``}``.

    Args:
        template: Template text in ``str.format`` syntax.

    Returns:
        The distinct field names found in the template, sorted.

    Raises:
        ValueError: On the first field that breaks one of the rules above.
    """
    names: set[str] = set()

    for field_name, spec in _parse_f_string_fields(template):
        # Attribute access / indexing syntax inside the field name.
        if any(symbol in field_name for symbol in (".", "[", "]")):
            msg = (
                f"Invalid variable name {field_name!r} in f-string template. "
                "Variable names cannot contain attribute "
                "access (.) or indexing ([])."
            )
            raise ValueError(msg)

        # Purely numeric names would be treated as positional arguments.
        if field_name.isdigit():
            msg = (
                f"Invalid variable name {field_name!r} in f-string template. "
                "Variable names cannot be all digits as they are interpreted "
                "as positional arguments."
            )
            raise ValueError(msg)

        # Braces in the format spec mean a nested replacement field.
        if spec and not {"{", "}"}.isdisjoint(spec):
            msg = (
                "Invalid format specifier in f-string template. "
                "Nested replacement fields are not allowed."
            )
            raise ValueError(msg)

        names.add(field_name)

    return sorted(names)
|
||||||
|
|
||||||
|
|
||||||
def check_valid_template(
|
def check_valid_template(
|
||||||
template: str, template_format: str, input_variables: list[str]
|
template: str, template_format: str, input_variables: list[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -243,6 +283,8 @@ def check_valid_template(
|
|||||||
f" {list(DEFAULT_FORMATTER_MAPPING)}."
|
f" {list(DEFAULT_FORMATTER_MAPPING)}."
|
||||||
)
|
)
|
||||||
raise ValueError(msg) from exc
|
raise ValueError(msg) from exc
|
||||||
|
if template_format == "f-string":
|
||||||
|
validate_f_string_template(template)
|
||||||
try:
|
try:
|
||||||
validator_func(template, input_variables)
|
validator_func(template, input_variables)
|
||||||
except (KeyError, IndexError) as exc:
|
except (KeyError, IndexError) as exc:
|
||||||
@@ -268,43 +310,18 @@ def get_template_variables(template: str, template_format: str) -> list[str]:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the template format is not supported.
|
ValueError: If the template format is not supported.
|
||||||
"""
|
"""
|
||||||
|
input_variables: list[str] | set[str]
|
||||||
if template_format == "jinja2":
|
if template_format == "jinja2":
|
||||||
# Get the variables for the template
|
# Get the variables for the template
|
||||||
input_variables = _get_jinja2_variables_from_template(template)
|
input_variables = sorted(_get_jinja2_variables_from_template(template))
|
||||||
elif template_format == "f-string":
|
elif template_format == "f-string":
|
||||||
input_variables = {
|
input_variables = validate_f_string_template(template)
|
||||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
|
||||||
}
|
|
||||||
elif template_format == "mustache":
|
elif template_format == "mustache":
|
||||||
input_variables = mustache_template_vars(template)
|
input_variables = mustache_template_vars(template)
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported template format: {template_format}"
|
msg = f"Unsupported template format: {template_format}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# For f-strings, block attribute access and indexing syntax
|
|
||||||
# This prevents template injection attacks via accessing dangerous attributes
|
|
||||||
if template_format == "f-string":
|
|
||||||
for var in input_variables:
|
|
||||||
# Formatter().parse() returns field names with dots/brackets if present
|
|
||||||
# e.g., "obj.attr" or "obj[0]" - we need to block these
|
|
||||||
if "." in var or "[" in var or "]" in var:
|
|
||||||
msg = (
|
|
||||||
f"Invalid variable name {var!r} in f-string template. "
|
|
||||||
f"Variable names cannot contain attribute "
|
|
||||||
f"access (.) or indexing ([])."
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
# Block variable names that are all digits (e.g., "0", "100")
|
|
||||||
# These are interpreted as positional arguments, not keyword arguments
|
|
||||||
if var.isdigit():
|
|
||||||
msg = (
|
|
||||||
f"Invalid variable name {var!r} in f-string template. "
|
|
||||||
f"Variable names cannot be all digits as they are interpreted "
|
|
||||||
f"as positional arguments."
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
return sorted(input_variables)
|
return sorted(input_variables)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1951,6 +1951,24 @@ def test_fstring_rejects_invalid_identifier_variable_names() -> None:
|
|||||||
assert result.messages[0].content == expected # type: ignore[attr-defined]
|
assert result.messages[0].content == expected # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fstring_rejects_nested_replacement_field_in_image_url() -> None:
    """An image URL whose format spec nests another field must be rejected."""
    image_block = {
        "type": "image_url",
        "image_url": {"url": "{img:{img.__class__.__name__}}"},
    }
    messages = [("human", [image_block])]

    with pytest.raises(ValueError, match="Nested replacement fields are not allowed"):
        ChatPromptTemplate.from_messages(messages, template_format="f-string")
|
||||||
|
|
||||||
|
|
||||||
def test_mustache_template_attribute_access_vulnerability() -> None:
|
def test_mustache_template_attribute_access_vulnerability() -> None:
|
||||||
"""Test that Mustache template injection is blocked.
|
"""Test that Mustache template injection is blocked.
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
from langchain_core.load import load
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_core.load import load, loads
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.prompts.dict import DictPromptTemplate
|
from langchain_core.prompts.dict import DictPromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@@ -32,3 +37,82 @@ def test_deserialize_legacy() -> None:
|
|||||||
template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
|
template={"type": "audio", "audio": "{audio_data}"}, template_format="f-string"
|
||||||
)
|
)
|
||||||
assert load(ser, allowed_objects=[DictPromptTemplate]) == expected
|
assert load(ser, allowed_objects=[DictPromptTemplate]) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_prompt_template_rejects_attribute_access_to_rich_objects() -> None:
    """Building a dict prompt with attribute/index access in a value must fail."""
    unsafe_template = {"output": "{message.additional_kwargs[secret]}"}

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        DictPromptTemplate(template=unsafe_template, template_format="f-string")
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_prompt_template_loads_payload_rejects_attribute_access() -> None:
    """``loads`` of a serialized dict prompt with attribute access must fail."""
    constructor_kwargs = {
        "template": {"output": "{message.additional_kwargs[secret]}"},
        "template_format": "f-string",
    }
    serialized = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain_core", "prompts", "dict", "DictPromptTemplate"],
        "kwargs": constructor_kwargs,
    }
    payload = json.dumps(serialized)

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        loads(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_prompt_template_dumpd_round_trip_rejects_attribute_access() -> None:
    """``load`` of an already-decoded payload with attribute access must fail."""
    constructor_kwargs = {
        "template": {"output": "{message.additional_kwargs[secret]}"},
        "template_format": "f-string",
    }
    payload = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain_core", "prompts", "dict", "DictPromptTemplate"],
        "kwargs": constructor_kwargs,
    }

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        load(payload, allowed_objects=[DictPromptTemplate])
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_prompt_template_deserialization_rejects_attribute_access() -> None:
    """``loads`` must refuse a dict prompt whose value walks dunder attributes."""
    constructor_kwargs = {
        "template": {"output": "{name.__class__.__name__}"},
        "template_format": "f-string",
    }
    serialized = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain_core", "prompts", "dict", "DictPromptTemplate"],
        "kwargs": constructor_kwargs,
    }
    payload = json.dumps(serialized)

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        loads(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_prompt_template_legacy_deserialization_rejects_attribute_access() -> None:
    """A payload using the ``_DictMessagePromptTemplate`` id is validated too."""
    constructor_kwargs = {
        "template_format": "f-string",
        "template": {"output": "{name.__class__.__name__}"},
    }
    ser = {
        "type": "constructor",
        "lc": 1,
        "id": ["langchain_core", "prompts", "message", "_DictMessagePromptTemplate"],
        "kwargs": constructor_kwargs,
    }

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        load(ser, allowed_objects=[DictPromptTemplate])
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_template_blocks_attribute_access() -> None:
    """``PromptTemplate.from_template`` must reject attribute access in a field."""
    expected_error = "Variable names cannot contain attribute access"

    with pytest.raises(ValueError, match=expected_error):
        PromptTemplate.from_template("{name.__class__}", template_format="f-string")
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain_core.load import dump, loads
|
from langchain_core.load import dump, loads
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.prompts.image import ImagePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
def test_image_prompt_template_deserializable() -> None:
|
def test_image_prompt_template_deserializable() -> None:
|
||||||
@@ -107,3 +110,31 @@ def test_image_prompt_template_deserializable_old() -> None:
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_prompt_template_rejects_attribute_access_in_template_values() -> None:
    """Constructing an image prompt with attribute access in a value must fail."""
    unsafe_template = {"url": "https://example.com/{image.__class__.__name__}.png"}

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        ImagePromptTemplate(input_variables=["image"], template=unsafe_template)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_prompt_template_deserialization_rejects_attribute_access() -> None:
    """``loads`` of a serialized image prompt with attribute access must fail."""
    constructor_kwargs = {
        "template": {"url": "https://example.com/{image.__class__.__name__}.png"},
        "input_variables": ["image"],
        "template_format": "f-string",
    }
    serialized = {
        "lc": 1,
        "type": "constructor",
        "id": ["langchain", "prompts", "image", "ImagePromptTemplate"],
        "kwargs": constructor_kwargs,
    }
    payload = json.dumps(serialized)

    with pytest.raises(ValueError, match="Variable names cannot contain attribute"):
        loads(payload)
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from langchain_core.prompts.string import get_template_variables, mustache_schema
|
from langchain_core.prompts.string import (
|
||||||
|
check_valid_template,
|
||||||
|
get_template_variables,
|
||||||
|
mustache_schema,
|
||||||
|
)
|
||||||
|
from langchain_core.utils.formatting import formatter
|
||||||
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
from langchain_core.utils.pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
PYDANTIC_VERSION_AT_LEAST_29 = version.parse("2.9") <= PYDANTIC_VERSION
|
||||||
@@ -39,3 +44,47 @@ def test_get_template_variables_mustache_nested() -> None:
|
|||||||
expected = ["user"]
|
expected = ["user"]
|
||||||
actual = get_template_variables(template, template_format)
|
actual = get_template_variables(template, template_format)
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_template_variables_rejects_nested_replacement_field_in_format_spec() -> (
    None
):
    """Variable extraction must refuse a format spec containing another field."""
    nested_template = "{name:{name.__class__.__name__}}"
    expected_error = "Nested replacement fields are not allowed"

    with pytest.raises(ValueError, match=expected_error):
        get_template_variables(nested_template, "f-string")
|
||||||
|
|
||||||
|
|
||||||
|
def test_formatter_rejects_nested_replacement_field_in_format_spec() -> None:
    """The shared formatter must refuse a format spec containing another field."""
    nested_template = "{name:{name.__class__.__name__}}"
    values = {"name": "hello"}

    with pytest.raises(ValueError, match="Invalid format specifier"):
        formatter.format(nested_template, **values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_valid_template_rejects_nested_replacement_field_in_format_spec() -> None:
    """``check_valid_template`` must refuse a format spec containing a field."""
    nested_template = "{name:{name.__class__.__name__}}"
    declared_variables = ["name"]
    expected_error = "Nested replacement fields are not allowed"

    with pytest.raises(ValueError, match=expected_error):
        check_valid_template(nested_template, "f-string", declared_variables)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("template", "kwargs", "expected_variables", "expected_output"),
    [
        ("{value:.2f}", {"value": 3.14159}, ["value"], "3.14"),
        # ">10" right-aligns in a field of width 10: seven pad spaces + "cat".
        ("{value:>10}", {"value": "cat"}, ["value"], "       cat"),
        ("{value:*^10}", {"value": "cat"}, ["value"], "***cat****"),
        ("{value:,}", {"value": 1234567}, ["value"], "1,234,567"),
        ("{value:%}", {"value": 0.125}, ["value"], "12.500000%"),
        ("{value!r}", {"value": "cat"}, ["value"], "'cat'"),
    ],
)
def test_f_string_templates_allow_safe_format_specs(
    template: str,
    kwargs: dict[str, object],
    expected_variables: list[str],
    expected_output: str,
) -> None:
    """Format specs and conversions without nested fields keep working.

    Each case checks both that variable extraction reports the expected names
    and that the shared formatter renders the expected text.
    """
    assert get_template_variables(template, "f-string") == expected_variables
    assert formatter.format(template, **kwargs) == expected_output
|
||||||
|
|||||||
Reference in New Issue
Block a user