mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(core): add more sanitization to templates (#36612)
add more sanitization to templates
This commit is contained in:
@@ -4,6 +4,7 @@ import warnings
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.load import dumpd
|
||||
@@ -21,11 +22,35 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]):
|
||||
Recognizes variables in f-string or mustache formatted string dict values.
|
||||
|
||||
Does NOT recognize variables in dict keys. Applies recursively.
|
||||
|
||||
Example:
|
||||
```python
|
||||
prompt = DictPromptTemplate(
|
||||
template={
|
||||
"type": "text",
|
||||
"text": "Hello {name}",
|
||||
"metadata": {"source": "{source}"},
|
||||
},
|
||||
template_format="f-string",
|
||||
)
|
||||
prompt.format(name="Alice", source="docs")
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": "Hello Alice",
|
||||
# "metadata": {"source": "docs"},
|
||||
# }
|
||||
```
|
||||
"""
|
||||
|
||||
template: dict[str, Any]
|
||||
template_format: Literal["f-string", "mustache"]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_template(self) -> "DictPromptTemplate":
|
||||
"""Validate that the template structure contains only safe variables."""
|
||||
_get_input_variables(self.template, self.template_format)
|
||||
return self
|
||||
|
||||
@property
|
||||
def input_variables(self) -> list[str]:
|
||||
"""Template input variables."""
|
||||
|
||||
@@ -9,12 +9,25 @@ from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
PromptTemplateFormat,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
|
||||
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
"""Image prompt template for a multimodal model."""
|
||||
"""Image prompt template for a multimodal model.
|
||||
|
||||
Example:
|
||||
```python
|
||||
prompt = ImagePromptTemplate(
|
||||
input_variables=["image_id"],
|
||||
template={"url": "https://example.com/{image_id}.png", "detail": "high"},
|
||||
template_format="f-string",
|
||||
)
|
||||
prompt.format(image_id="cat")
|
||||
# {"url": "https://example.com/cat.png", "detail": "high"}
|
||||
```
|
||||
"""
|
||||
|
||||
template: dict = Field(default_factory=dict)
|
||||
"""Template for the prompt."""
|
||||
@@ -43,6 +56,13 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
f" Found: {overlap}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
template = kwargs.get("template", {})
|
||||
template_format = kwargs.get("template_format", "f-string")
|
||||
for value in template.values():
|
||||
if isinstance(value, str):
|
||||
get_template_variables(value, template_format)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
|
||||
@@ -219,6 +219,46 @@ DEFAULT_VALIDATOR_MAPPING: dict[str, Callable] = {
|
||||
}
|
||||
|
||||
|
||||
def _parse_f_string_fields(template: str) -> list[tuple[str, str | None]]:
|
||||
fields: list[tuple[str, str | None]] = []
|
||||
for _, field_name, format_spec, _ in Formatter().parse(template):
|
||||
if field_name is not None:
|
||||
fields.append((field_name, format_spec))
|
||||
return fields
|
||||
|
||||
|
||||
def validate_f_string_template(template: str) -> list[str]:
|
||||
"""Validate an f-string template and return its input variables."""
|
||||
input_variables = set()
|
||||
for var, format_spec in _parse_f_string_fields(template):
|
||||
if "." in var or "[" in var or "]" in var:
|
||||
msg = (
|
||||
f"Invalid variable name {var!r} in f-string template. "
|
||||
f"Variable names cannot contain attribute "
|
||||
f"access (.) or indexing ([])."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if var.isdigit():
|
||||
msg = (
|
||||
f"Invalid variable name {var!r} in f-string template. "
|
||||
f"Variable names cannot be all digits as they are interpreted "
|
||||
f"as positional arguments."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if format_spec and ("{" in format_spec or "}" in format_spec):
|
||||
msg = (
|
||||
"Invalid format specifier in f-string template. "
|
||||
"Nested replacement fields are not allowed."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
input_variables.add(var)
|
||||
|
||||
return sorted(input_variables)
|
||||
|
||||
|
||||
def check_valid_template(
|
||||
template: str, template_format: str, input_variables: list[str]
|
||||
) -> None:
|
||||
@@ -243,6 +283,8 @@ def check_valid_template(
|
||||
f" {list(DEFAULT_FORMATTER_MAPPING)}."
|
||||
)
|
||||
raise ValueError(msg) from exc
|
||||
if template_format == "f-string":
|
||||
validate_f_string_template(template)
|
||||
try:
|
||||
validator_func(template, input_variables)
|
||||
except (KeyError, IndexError) as exc:
|
||||
@@ -268,43 +310,18 @@ def get_template_variables(template: str, template_format: str) -> list[str]:
|
||||
Raises:
|
||||
ValueError: If the template format is not supported.
|
||||
"""
|
||||
input_variables: list[str] | set[str]
|
||||
if template_format == "jinja2":
|
||||
# Get the variables for the template
|
||||
input_variables = _get_jinja2_variables_from_template(template)
|
||||
input_variables = sorted(_get_jinja2_variables_from_template(template))
|
||||
elif template_format == "f-string":
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
input_variables = validate_f_string_template(template)
|
||||
elif template_format == "mustache":
|
||||
input_variables = mustache_template_vars(template)
|
||||
else:
|
||||
msg = f"Unsupported template format: {template_format}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# For f-strings, block attribute access and indexing syntax
|
||||
# This prevents template injection attacks via accessing dangerous attributes
|
||||
if template_format == "f-string":
|
||||
for var in input_variables:
|
||||
# Formatter().parse() returns field names with dots/brackets if present
|
||||
# e.g., "obj.attr" or "obj[0]" - we need to block these
|
||||
if "." in var or "[" in var or "]" in var:
|
||||
msg = (
|
||||
f"Invalid variable name {var!r} in f-string template. "
|
||||
f"Variable names cannot contain attribute "
|
||||
f"access (.) or indexing ([])."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Block variable names that are all digits (e.g., "0", "100")
|
||||
# These are interpreted as positional arguments, not keyword arguments
|
||||
if var.isdigit():
|
||||
msg = (
|
||||
f"Invalid variable name {var!r} in f-string template. "
|
||||
f"Variable names cannot be all digits as they are interpreted "
|
||||
f"as positional arguments."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return sorted(input_variables)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user