mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-08 05:23:10 +00:00
fix(core): harden check for txt files in deprecated prompt loading functions (#36471)
This commit is contained in:
@@ -96,9 +96,12 @@ def _load_template(
|
||||
template_path = Path(config.pop(f"{var_name}_path"))
|
||||
if not allow_dangerous_paths:
|
||||
_validate_path(template_path)
|
||||
# Resolve symlinks before checking the suffix so that a symlink named
|
||||
# "exploit.txt" pointing to a non-.txt file is caught.
|
||||
resolved_path = template_path.resolve()
|
||||
# Load the template.
|
||||
if template_path.suffix == ".txt":
|
||||
template = template_path.read_text(encoding="utf-8")
|
||||
if resolved_path.suffix == ".txt":
|
||||
template = resolved_path.read_text(encoding="utf-8")
|
||||
else:
|
||||
raise ValueError
|
||||
# Set the template variable to the extracted variable.
|
||||
|
||||
@@ -278,6 +278,53 @@ def test_load_prompt_from_config_few_shot_rejects_absolute_example_prompt_path(
|
||||
load_prompt_from_config(config)
|
||||
|
||||
|
||||
def test_symlink_txt_to_py_is_blocked(tmp_path: Path) -> None:
|
||||
"""Test symlink redirects cannot get around file extension check."""
|
||||
sensitive = tmp_path / "sensitive_source.py"
|
||||
sensitive.write_text("INTERNAL_SECRET='ABC-123-XYZ'")
|
||||
symlink = tmp_path / "exploit_link.txt"
|
||||
symlink.symlink_to(sensitive)
|
||||
|
||||
config = {
|
||||
"_type": "prompt",
|
||||
"template_path": "exploit_link.txt",
|
||||
"input_variables": [],
|
||||
}
|
||||
original_dir = Path.cwd()
|
||||
try:
|
||||
os.chdir(tmp_path)
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError), # noqa: PT011
|
||||
):
|
||||
load_prompt_from_config(config)
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
def test_symlink_jinja2_rce_is_blocked(tmp_path: Path) -> None:
|
||||
"""Check jinja2 templates cannot be used to perform RCE via symlinks."""
|
||||
payload = tmp_path / "rce_payload.py"
|
||||
payload.write_text(
|
||||
"{{ self.__init__.__globals__.__builtins__"
|
||||
".__import__('os').popen('id').read() }}"
|
||||
)
|
||||
symlink = tmp_path / "rce_bypass.txt"
|
||||
symlink.symlink_to(payload)
|
||||
|
||||
config = {
|
||||
"_type": "prompt",
|
||||
"template_path": str(symlink),
|
||||
"template_format": "jinja2",
|
||||
"input_variables": [],
|
||||
}
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError), # noqa: PT011
|
||||
):
|
||||
load_prompt_from_config(config, allow_dangerous_paths=True)
|
||||
|
||||
|
||||
def test_loading_few_shot_prompt_from_yaml() -> None:
|
||||
"""Test loading few shot prompt from yaml."""
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
|
||||
Reference in New Issue
Block a user