diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 7aa7ea58d9a..c96c936d80f 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -389,11 +389,12 @@ class BasePromptTemplate( directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) - if save_path.suffix == ".json": - with save_path.open("w", encoding="utf-8") as f: + resolved_path = save_path.resolve() + if resolved_path.suffix == ".json": + with resolved_path.open("w", encoding="utf-8") as f: json.dump(prompt_dict, f, indent=4) - elif save_path.suffix.endswith((".yaml", ".yml")): - with save_path.open("w", encoding="utf-8") as f: + elif resolved_path.suffix.endswith((".yaml", ".yml")): + with resolved_path.open("w", encoding="utf-8") as f: yaml.dump(prompt_dict, f, default_flow_style=False) else: msg = f"{save_path} must be json or yaml" diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index 91af9257c0a..20c7399c1db 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -325,6 +325,22 @@ def test_symlink_jinja2_rce_is_blocked(tmp_path: Path) -> None: load_prompt_from_config(config, allow_dangerous_paths=True) +def test_save_symlink_to_py_is_blocked(tmp_path: Path) -> None: + """Test that save() resolves symlinks before checking the file extension.""" + target = tmp_path / "malicious.py" + symlink = tmp_path / "output.json" + symlink.symlink_to(target) + + prompt = PromptTemplate(input_variables=["name"], template="Hello {name}") + with ( + suppress_langchain_deprecation_warning(), + pytest.raises(ValueError, match="must be json or yaml"), + ): + prompt.save(symlink) + + assert not target.exists() + + def test_loading_few_shot_prompt_from_yaml() -> None: """Test loading few shot prompt from yaml.""" with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():