From 8b0060184d6510626f9ad5a8580dd555d2cd68a7 Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 5 Dec 2023 16:13:08 -0500 Subject: [PATCH] Fixing empty input variable crashing `PromptTemplate` validations (#14314) - Fixes `input_variables=[""]` crashing validations with a template `"{}"` - Uses `__cause__` for proper `Exception` chaining in `check_valid_template` --- libs/core/langchain_core/prompts/string.py | 22 +++++++++---------- .../tests/unit_tests/prompts/test_prompt.py | 6 +++++ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index d61660f2b8a..e454b8280ae 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -106,20 +106,20 @@ def check_valid_template( Raises: ValueError: If the template format is not supported. """ - if template_format not in DEFAULT_FORMATTER_MAPPING: - valid_formats = list(DEFAULT_FORMATTER_MAPPING) - raise ValueError( - f"Invalid template format. Got `{template_format}`;" - f" should be one of {valid_formats}" - ) try: validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] - validator_func(template, input_variables) - except KeyError as e: + except KeyError as exc: raise ValueError( - "Invalid prompt schema; check for mismatched or missing input parameters. " - + str(e) - ) + f"Invalid template format {template_format!r}, should be one of" + f" {list(DEFAULT_FORMATTER_MAPPING)}." + ) from exc + try: + validator_func(template, input_variables) + except (KeyError, IndexError) as exc: + raise ValueError( + "Invalid prompt schema; check for mismatched or missing input parameters" + f" from {input_variables}." + ) from exc def get_template_variables(template: str, template_format: str) -> List[str]: diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index f931dd80bc1..1b63b8859e1 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -47,6 +47,12 @@ def test_prompt_missing_input_variables() -> None: ).input_variables == ["foo"] +def test_prompt_empty_input_variable() -> None: + """Test error is raised when empty string input variable.""" + with pytest.raises(ValueError): + PromptTemplate(input_variables=[""], template="{}", validate_template=True) + + def test_prompt_extra_input_variables() -> None: """Test error is raised when there are too many input variables.""" template = "This is a {foo} test."