diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 0066445a050..e55d4c1df34 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -138,11 +138,8 @@ class PromptTemplate(StringPromptTemplate): """Override the + operator to allow for combining prompt templates.""" # Allow for easy combining if isinstance(other, PromptTemplate): - if self.template_format != "f-string": - msg = "Adding prompt templates only supported for f-strings." - raise ValueError(msg) - if other.template_format != "f-string": - msg = "Adding prompt templates only supported for f-strings." + if self.template_format != other.template_format: + msg = "Cannot add templates of different formats" raise ValueError(msg) input_variables = list( set(self.input_variables) | set(other.input_variables) @@ -160,11 +157,14 @@ class PromptTemplate(StringPromptTemplate): template=template, input_variables=input_variables, partial_variables=partial_variables, - template_format="f-string", + template_format=self.template_format, validate_template=validate_template, ) if isinstance(other, str): - prompt = PromptTemplate.from_template(other) + prompt = PromptTemplate.from_template( + other, + template_format=self.template_format, + ) return self + prompt msg = f"Unsupported operand type for +: {type(other)}" raise NotImplementedError(msg) diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index e092eb66581..bc4780e1541 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -1,7 +1,7 @@ """Test functionality related to prompts.""" import re -from typing import Any, Union +from typing import Any, Literal, Union from unittest import mock import pytest @@ -681,3 +681,54 @@ def test_prompt_with_template_variable_name_jinja2() -> None: template = "This is a {{template}} test." prompt = PromptTemplate.from_template(template, template_format="jinja2") assert prompt.invoke({"template": "bar"}).to_string() == "This is a bar test." + + +def test_prompt_template_add_with_with_another_format() -> None: + with pytest.raises(ValueError, match=r"Cannot add templates"): + ( + PromptTemplate.from_template("This is a {template}") + + PromptTemplate.from_template("So {{this}} is", template_format="mustache") + ) + + +@pytest.mark.parametrize( + ("template_format", "prompt1", "prompt2"), + [ + ("f-string", "This is a {variable}", ". This is {another_variable}"), + pytest.param( + "jinja2", + "This is a {{variable}}", + ". This is {{another_variable}}", + marks=[pytest.mark.requires("jinja2")], + ), + ("mustache", "This is a {{variable}}", ". This is {{another_variable}}"), + ], +) +def test_prompt_template_add( + template_format: Literal["f-string", "mustache", "jinja2"], + prompt1: str, + prompt2: str, +) -> None: + first_prompt = PromptTemplate.from_template( + prompt1, + template_format=template_format, + ) + second_prompt = PromptTemplate.from_template( + prompt2, + template_format=template_format, + ) + + concated_prompt = first_prompt + second_prompt + prompt_of_concated = PromptTemplate.from_template( + prompt1 + prompt2, + template_format=template_format, + ) + + assert concated_prompt.input_variables == prompt_of_concated.input_variables + assert concated_prompt.format( + variable="template", + another_variable="other_template", + ) == prompt_of_concated.format( + variable="template", + another_variable="other_template", + )