mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-03 19:04:23 +00:00
fix(core): validate paths in prompt.save and load_prompt, deprecate methods (#36200)
This commit is contained in:
@@ -15,6 +15,7 @@ import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.exceptions import ErrorCode, create_message
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.output_parsers.base import BaseOutputParser # noqa: TC001
|
||||
@@ -350,6 +351,12 @@ class BasePromptTemplate(
|
||||
prompt_dict["_type"] = self._prompt_type
|
||||
return prompt_dict
|
||||
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
alternative="Use `dumpd`/`dumps` from `langchain_core.load` to serialize "
|
||||
"prompts and `load`/`loads` to deserialize them.",
|
||||
)
|
||||
def save(self, file_path: Path | str) -> None:
|
||||
"""Save the prompt.
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
@@ -1305,6 +1306,12 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
"""Name of prompt type. Used for serialization."""
|
||||
return "chat"
|
||||
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
alternative="Use `dumpd`/`dumps` from `langchain_core.load` to serialize "
|
||||
"prompts and `load`/`loads` to deserialize them.",
|
||||
)
|
||||
def save(self, file_path: Path | str) -> None:
|
||||
"""Save prompt to file.
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.prompts.chat import BaseChatPromptTemplate
|
||||
@@ -237,6 +238,12 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
"""Return the prompt type key."""
|
||||
return "few_shot"
|
||||
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
alternative="Use `dumpd`/`dumps` from `langchain_core.load` to serialize "
|
||||
"prompts and `load`/`loads` to deserialize them.",
|
||||
)
|
||||
def save(self, file_path: Path | str) -> None:
|
||||
"""Save the prompt template to a file.
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
@@ -215,6 +216,12 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Return the prompt type key."""
|
||||
return "few_shot_with_templates"
|
||||
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
alternative="Use `dumpd`/`dumps` from `langchain_core.load` to serialize "
|
||||
"prompts and `load`/`loads` to deserialize them.",
|
||||
)
|
||||
def save(self, file_path: Path | str) -> None:
|
||||
"""Save the prompt to a file.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.output_parsers.string import StrOutputParser
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
@@ -17,11 +18,51 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/pro
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
||||
def _validate_path(path: Path) -> None:
|
||||
"""Reject absolute paths and ``..`` traversal components.
|
||||
|
||||
Args:
|
||||
path: The path to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path is absolute or contains ``..`` components.
|
||||
"""
|
||||
if path.is_absolute():
|
||||
msg = (
|
||||
f"Path '{path}' is absolute. Absolute paths are not allowed "
|
||||
f"when loading prompt configurations to prevent path traversal "
|
||||
f"attacks. Use relative paths instead, or pass "
|
||||
f"`allow_dangerous_paths=True` if you trust the input."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if ".." in path.parts:
|
||||
msg = (
|
||||
f"Path '{path}' contains '..' components. Directory traversal "
|
||||
f"sequences are not allowed when loading prompt configurations. "
|
||||
f"Use direct relative paths instead, or pass "
|
||||
f"`allow_dangerous_paths=True` if you trust the input."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
alternative="Use `dumpd`/`dumps` from `langchain_core.load` to serialize "
|
||||
"prompts and `load`/`loads` to deserialize them.",
|
||||
)
|
||||
def load_prompt_from_config(
|
||||
config: dict, *, allow_dangerous_paths: bool = False
|
||||
) -> BasePromptTemplate:
|
||||
"""Load prompt from config dict.
|
||||
|
||||
Args:
|
||||
config: Dict containing the prompt configuration.
|
||||
allow_dangerous_paths: If ``False`` (default), file paths in the
|
||||
config (such as ``template_path``, ``examples``, and
|
||||
``example_prompt_path``) are validated to reject absolute paths
|
||||
and directory traversal (``..``) sequences. Set to ``True`` only
|
||||
if you trust the source of the config.
|
||||
|
||||
Returns:
|
||||
A `PromptTemplate` object.
|
||||
@@ -38,10 +79,12 @@ def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
||||
raise ValueError(msg)
|
||||
|
||||
prompt_loader = type_to_loader_dict[config_type]
|
||||
return prompt_loader(config)
|
||||
return prompt_loader(config, allow_dangerous_paths=allow_dangerous_paths)
|
||||
|
||||
|
||||
def _load_template(var_name: str, config: dict) -> dict:
|
||||
def _load_template(
|
||||
var_name: str, config: dict, *, allow_dangerous_paths: bool = False
|
||||
) -> dict:
|
||||
"""Load template from the path if applicable."""
|
||||
# Check if template_path exists in config.
|
||||
if f"{var_name}_path" in config:
|
||||
@@ -51,6 +94,8 @@ def _load_template(var_name: str, config: dict) -> dict:
|
||||
raise ValueError(msg)
|
||||
# Pop the template path from the config.
|
||||
template_path = Path(config.pop(f"{var_name}_path"))
|
||||
if not allow_dangerous_paths:
|
||||
_validate_path(template_path)
|
||||
# Load the template.
|
||||
if template_path.suffix == ".txt":
|
||||
template = template_path.read_text(encoding="utf-8")
|
||||
@@ -61,12 +106,14 @@ def _load_template(var_name: str, config: dict) -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def _load_examples(config: dict) -> dict:
|
||||
def _load_examples(config: dict, *, allow_dangerous_paths: bool = False) -> dict:
|
||||
"""Load examples if necessary."""
|
||||
if isinstance(config["examples"], list):
|
||||
pass
|
||||
elif isinstance(config["examples"], str):
|
||||
path = Path(config["examples"])
|
||||
if not allow_dangerous_paths:
|
||||
_validate_path(path)
|
||||
with path.open(encoding="utf-8") as f:
|
||||
if path.suffix == ".json":
|
||||
examples = json.load(f)
|
||||
@@ -92,11 +139,17 @@ def _load_output_parser(config: dict) -> dict:
|
||||
return config
|
||||
|
||||
|
||||
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
|
||||
def _load_few_shot_prompt(
|
||||
config: dict, *, allow_dangerous_paths: bool = False
|
||||
) -> FewShotPromptTemplate:
|
||||
"""Load the "few shot" prompt from the config."""
|
||||
# Load the suffix and prefix templates.
|
||||
config = _load_template("suffix", config)
|
||||
config = _load_template("prefix", config)
|
||||
config = _load_template(
|
||||
"suffix", config, allow_dangerous_paths=allow_dangerous_paths
|
||||
)
|
||||
config = _load_template(
|
||||
"prefix", config, allow_dangerous_paths=allow_dangerous_paths
|
||||
)
|
||||
# Load the example prompt.
|
||||
if "example_prompt_path" in config:
|
||||
if "example_prompt" in config:
|
||||
@@ -105,19 +158,30 @@ def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
|
||||
"be specified."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
config["example_prompt"] = load_prompt(config.pop("example_prompt_path"))
|
||||
example_prompt_path = Path(config.pop("example_prompt_path"))
|
||||
if not allow_dangerous_paths:
|
||||
_validate_path(example_prompt_path)
|
||||
config["example_prompt"] = load_prompt(
|
||||
example_prompt_path, allow_dangerous_paths=allow_dangerous_paths
|
||||
)
|
||||
else:
|
||||
config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
|
||||
config["example_prompt"] = load_prompt_from_config(
|
||||
config["example_prompt"], allow_dangerous_paths=allow_dangerous_paths
|
||||
)
|
||||
# Load the examples.
|
||||
config = _load_examples(config)
|
||||
config = _load_examples(config, allow_dangerous_paths=allow_dangerous_paths)
|
||||
config = _load_output_parser(config)
|
||||
return FewShotPromptTemplate(**config)
|
||||
|
||||
|
||||
def _load_prompt(config: dict) -> PromptTemplate:
|
||||
def _load_prompt(
|
||||
config: dict, *, allow_dangerous_paths: bool = False
|
||||
) -> PromptTemplate:
|
||||
"""Load the prompt template from config."""
|
||||
# Load the template from disk if necessary.
|
||||
config = _load_template("template", config)
|
||||
config = _load_template(
|
||||
"template", config, allow_dangerous_paths=allow_dangerous_paths
|
||||
)
|
||||
config = _load_output_parser(config)
|
||||
|
||||
template_format = config.get("template_format", "f-string")
|
||||
@@ -134,12 +198,28 @@ def _load_prompt(config: dict) -> PromptTemplate:
|
||||
return PromptTemplate(**config)
|
||||
|
||||
|
||||
def load_prompt(path: str | Path, encoding: str | None = None) -> BasePromptTemplate:
|
||||
@deprecated(
|
||||
since="1.2.21",
|
||||
removal="2.0.0",
|
||||
alternative="Use `dumpd`/`dumps` from `langchain_core.load` to serialize "
|
||||
"prompts and `load`/`loads` to deserialize them.",
|
||||
)
|
||||
def load_prompt(
|
||||
path: str | Path,
|
||||
encoding: str | None = None,
|
||||
*,
|
||||
allow_dangerous_paths: bool = False,
|
||||
) -> BasePromptTemplate:
|
||||
"""Unified method for loading a prompt from LangChainHub or local filesystem.
|
||||
|
||||
Args:
|
||||
path: Path to the prompt file.
|
||||
encoding: Encoding of the file.
|
||||
allow_dangerous_paths: If ``False`` (default), file paths referenced
|
||||
inside the loaded config (such as ``template_path``, ``examples``,
|
||||
and ``example_prompt_path``) are validated to reject absolute paths
|
||||
and directory traversal (``..``) sequences. Set to ``True`` only
|
||||
if you trust the source of the config.
|
||||
|
||||
Returns:
|
||||
A `PromptTemplate` object.
|
||||
@@ -154,11 +234,16 @@ def load_prompt(path: str | Path, encoding: str | None = None) -> BasePromptTemp
|
||||
"instead."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
return _load_prompt_from_file(path, encoding)
|
||||
return _load_prompt_from_file(
|
||||
path, encoding, allow_dangerous_paths=allow_dangerous_paths
|
||||
)
|
||||
|
||||
|
||||
def _load_prompt_from_file(
|
||||
file: str | Path, encoding: str | None = None
|
||||
file: str | Path,
|
||||
encoding: str | None = None,
|
||||
*,
|
||||
allow_dangerous_paths: bool = False,
|
||||
) -> BasePromptTemplate:
|
||||
"""Load prompt from file."""
|
||||
# Convert file to a Path object.
|
||||
@@ -174,10 +259,14 @@ def _load_prompt_from_file(
|
||||
msg = f"Got unsupported file type {file_path.suffix}"
|
||||
raise ValueError(msg)
|
||||
# Load the prompt from the config now.
|
||||
return load_prompt_from_config(config)
|
||||
return load_prompt_from_config(config, allow_dangerous_paths=allow_dangerous_paths)
|
||||
|
||||
|
||||
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||
def _load_chat_prompt(
|
||||
config: dict,
|
||||
*,
|
||||
allow_dangerous_paths: bool = False, # noqa: ARG001
|
||||
) -> ChatPromptTemplate:
|
||||
"""Load chat prompt from config."""
|
||||
messages = config.pop("messages")
|
||||
template = messages[0]["prompt"].pop("template") if messages else None
|
||||
@@ -190,7 +279,7 @@ def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||
return ChatPromptTemplate.from_template(template=template, **config)
|
||||
|
||||
|
||||
type_to_loader_dict: dict[str, Callable[[dict], BasePromptTemplate]] = {
|
||||
type_to_loader_dict: dict[str, Callable[..., BasePromptTemplate]] = {
|
||||
"prompt": _load_prompt,
|
||||
"few_shot": _load_few_shot_prompt,
|
||||
"chat": _load_chat_prompt,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Test loading functionality."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
@@ -7,8 +8,14 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core._api import suppress_langchain_deprecation_warning
|
||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain_core.prompts.loading import load_prompt
|
||||
from langchain_core.prompts.loading import (
|
||||
_load_examples,
|
||||
_load_template,
|
||||
load_prompt,
|
||||
load_prompt_from_config,
|
||||
)
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
|
||||
EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()
|
||||
@@ -27,7 +34,8 @@ def change_directory(dir_path: Path) -> Iterator[None]:
|
||||
|
||||
def test_loading_from_yaml() -> None:
|
||||
"""Test loading from yaml file."""
|
||||
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
|
||||
with suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml")
|
||||
expected_prompt = PromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
partial_variables={"content": "dogs"},
|
||||
@@ -38,7 +46,8 @@ def test_loading_from_yaml() -> None:
|
||||
|
||||
def test_loading_from_json() -> None:
|
||||
"""Test loading from json file."""
|
||||
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
|
||||
with suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json")
|
||||
expected_prompt = PromptTemplate(
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
@@ -49,14 +58,20 @@ def test_loading_from_json() -> None:
|
||||
def test_loading_jinja_from_json() -> None:
|
||||
"""Test that loading jinja2 format prompts from JSON raises ValueError."""
|
||||
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json"
|
||||
with pytest.raises(ValueError, match=r".*can lead to arbitrary code execution.*"):
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match=r".*can lead to arbitrary code execution.*"),
|
||||
):
|
||||
load_prompt(prompt_path)
|
||||
|
||||
|
||||
def test_loading_jinja_from_yaml() -> None:
|
||||
"""Test that loading jinja2 format prompts from YAML raises ValueError."""
|
||||
prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml"
|
||||
with pytest.raises(ValueError, match=r".*can lead to arbitrary code execution.*"):
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match=r".*can lead to arbitrary code execution.*"),
|
||||
):
|
||||
load_prompt(prompt_path)
|
||||
|
||||
|
||||
@@ -66,8 +81,9 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
)
|
||||
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||
with suppress_langchain_deprecation_warning():
|
||||
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||
assert loaded_prompt == simple_prompt
|
||||
|
||||
few_shot_prompt = FewShotPromptTemplate(
|
||||
@@ -83,15 +99,18 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||
],
|
||||
suffix="Input: {adjective}\nOutput:",
|
||||
)
|
||||
few_shot_prompt.save(file_path=tmp_path / "few_shot.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "few_shot.yaml")
|
||||
with suppress_langchain_deprecation_warning():
|
||||
few_shot_prompt.save(file_path=tmp_path / "few_shot.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "few_shot.yaml")
|
||||
assert loaded_prompt == few_shot_prompt
|
||||
|
||||
|
||||
def test_loading_with_template_as_file() -> None:
|
||||
"""Test loading when the template is a file."""
|
||||
with change_directory(EXAMPLE_DIR):
|
||||
prompt = load_prompt("simple_prompt_with_template_file.json")
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt(
|
||||
"simple_prompt_with_template_file.json", allow_dangerous_paths=True
|
||||
)
|
||||
expected_prompt = PromptTemplate(
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
@@ -99,10 +118,170 @@ def test_loading_with_template_as_file() -> None:
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_load_template_rejects_absolute_path(tmp_path: Path) -> None:
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("SECRET")
|
||||
config = {"template_path": str(secret)}
|
||||
with pytest.raises(ValueError, match="is absolute"):
|
||||
_load_template("template", config)
|
||||
|
||||
|
||||
def test_load_template_rejects_traversal() -> None:
|
||||
config = {"template_path": "../../etc/secret.txt"}
|
||||
with pytest.raises(ValueError, match=r"contains '\.\.' components"):
|
||||
_load_template("template", config)
|
||||
|
||||
|
||||
def test_load_template_allows_dangerous_paths_when_opted_in(tmp_path: Path) -> None:
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("SECRET")
|
||||
config = {"template_path": str(secret)}
|
||||
result = _load_template("template", config, allow_dangerous_paths=True)
|
||||
assert result["template"] == "SECRET"
|
||||
|
||||
|
||||
def test_load_examples_rejects_absolute_path(tmp_path: Path) -> None:
|
||||
examples_file = tmp_path / "examples.json"
|
||||
examples_file.write_text(json.dumps([{"input": "a", "output": "b"}]))
|
||||
config = {"examples": str(examples_file)}
|
||||
with pytest.raises(ValueError, match="is absolute"):
|
||||
_load_examples(config)
|
||||
|
||||
|
||||
def test_load_examples_rejects_traversal() -> None:
|
||||
config = {"examples": "../../secrets/data.json"}
|
||||
with pytest.raises(ValueError, match=r"contains '\.\.' components"):
|
||||
_load_examples(config)
|
||||
|
||||
|
||||
def test_load_examples_allows_dangerous_paths_when_opted_in(tmp_path: Path) -> None:
|
||||
examples_file = tmp_path / "examples.json"
|
||||
examples_file.write_text(json.dumps([{"input": "a", "output": "b"}]))
|
||||
config = {"examples": str(examples_file)}
|
||||
result = _load_examples(config, allow_dangerous_paths=True)
|
||||
assert result["examples"] == [{"input": "a", "output": "b"}]
|
||||
|
||||
|
||||
def test_load_prompt_from_config_rejects_absolute_template_path(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("SECRET")
|
||||
config = {
|
||||
"_type": "prompt",
|
||||
"template_path": str(secret),
|
||||
"input_variables": [],
|
||||
}
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match="is absolute"),
|
||||
):
|
||||
load_prompt_from_config(config)
|
||||
|
||||
|
||||
def test_load_prompt_from_config_rejects_traversal_template_path() -> None:
|
||||
config = {
|
||||
"_type": "prompt",
|
||||
"template_path": "../../../tmp/secret.txt",
|
||||
"input_variables": [],
|
||||
}
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match=r"contains '\.\.' components"),
|
||||
):
|
||||
load_prompt_from_config(config)
|
||||
|
||||
|
||||
def test_load_prompt_from_config_allows_dangerous_paths(tmp_path: Path) -> None:
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("SECRET")
|
||||
config = {
|
||||
"_type": "prompt",
|
||||
"template_path": str(secret),
|
||||
"input_variables": [],
|
||||
}
|
||||
with suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt_from_config(config, allow_dangerous_paths=True)
|
||||
assert isinstance(prompt, PromptTemplate)
|
||||
assert prompt.template == "SECRET"
|
||||
|
||||
|
||||
def test_load_prompt_from_config_few_shot_rejects_traversal_examples() -> None:
|
||||
config = {
|
||||
"_type": "few_shot",
|
||||
"input_variables": ["query"],
|
||||
"prefix": "Examples:",
|
||||
"example_prompt": {
|
||||
"_type": "prompt",
|
||||
"input_variables": ["input", "output"],
|
||||
"template": "{input}: {output}",
|
||||
},
|
||||
"examples": "../../../../.docker/config.json",
|
||||
"suffix": "Query: {query}",
|
||||
}
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match=r"contains '\.\.' components"),
|
||||
):
|
||||
load_prompt_from_config(config)
|
||||
|
||||
|
||||
def test_load_prompt_from_config_few_shot_rejects_absolute_examples(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
examples_file = tmp_path / "examples.json"
|
||||
examples_file.write_text(json.dumps([{"input": "a", "output": "b"}]))
|
||||
config = {
|
||||
"_type": "few_shot",
|
||||
"input_variables": ["query"],
|
||||
"prefix": "Examples:",
|
||||
"example_prompt": {
|
||||
"_type": "prompt",
|
||||
"input_variables": ["input", "output"],
|
||||
"template": "{input}: {output}",
|
||||
},
|
||||
"examples": str(examples_file),
|
||||
"suffix": "Query: {query}",
|
||||
}
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match="is absolute"),
|
||||
):
|
||||
load_prompt_from_config(config)
|
||||
|
||||
|
||||
def test_load_prompt_from_config_few_shot_rejects_absolute_example_prompt_path(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
prompt_file = tmp_path / "prompt.json"
|
||||
prompt_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"_type": "prompt",
|
||||
"template": "{input}: {output}",
|
||||
"input_variables": ["input", "output"],
|
||||
}
|
||||
)
|
||||
)
|
||||
config = {
|
||||
"_type": "few_shot",
|
||||
"input_variables": ["query"],
|
||||
"prefix": "Examples:",
|
||||
"example_prompt_path": str(prompt_file),
|
||||
"examples": [{"input": "a", "output": "b"}],
|
||||
"suffix": "Query: {query}",
|
||||
}
|
||||
with (
|
||||
suppress_langchain_deprecation_warning(),
|
||||
pytest.raises(ValueError, match="is absolute"),
|
||||
):
|
||||
load_prompt_from_config(config)
|
||||
|
||||
|
||||
def test_loading_few_shot_prompt_from_yaml() -> None:
|
||||
"""Test loading few shot prompt from yaml."""
|
||||
with change_directory(EXAMPLE_DIR):
|
||||
prompt = load_prompt("few_shot_prompt.yaml")
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt("few_shot_prompt.yaml", allow_dangerous_paths=True)
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
@@ -121,8 +300,8 @@ def test_loading_few_shot_prompt_from_yaml() -> None:
|
||||
|
||||
def test_loading_few_shot_prompt_from_json() -> None:
|
||||
"""Test loading few shot prompt from json."""
|
||||
with change_directory(EXAMPLE_DIR):
|
||||
prompt = load_prompt("few_shot_prompt.json")
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt("few_shot_prompt.json", allow_dangerous_paths=True)
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
@@ -141,8 +320,10 @@ def test_loading_few_shot_prompt_from_json() -> None:
|
||||
|
||||
def test_loading_few_shot_prompt_when_examples_in_config() -> None:
|
||||
"""Test loading few shot prompt when the examples are in the config."""
|
||||
with change_directory(EXAMPLE_DIR):
|
||||
prompt = load_prompt("few_shot_prompt_examples_in.json")
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt(
|
||||
"few_shot_prompt_examples_in.json", allow_dangerous_paths=True
|
||||
)
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
@@ -161,8 +342,10 @@ def test_loading_few_shot_prompt_when_examples_in_config() -> None:
|
||||
|
||||
def test_loading_few_shot_prompt_example_prompt() -> None:
|
||||
"""Test loading few shot when the example prompt is in its own file."""
|
||||
with change_directory(EXAMPLE_DIR):
|
||||
prompt = load_prompt("few_shot_prompt_example_prompt.json")
|
||||
with change_directory(EXAMPLE_DIR), suppress_langchain_deprecation_warning():
|
||||
prompt = load_prompt(
|
||||
"few_shot_prompt_example_prompt.json", allow_dangerous_paths=True
|
||||
)
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
|
||||
Reference in New Issue
Block a user