This commit is contained in:
Sadra Barikbin 2025-07-28 17:40:39 -07:00 committed by GitHub
commit 2fb9bb4920
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 8 deletions

View File

@ -138,11 +138,8 @@ class PromptTemplate(StringPromptTemplate):
"""Override the + operator to allow for combining prompt templates.""" """Override the + operator to allow for combining prompt templates."""
# Allow for easy combining # Allow for easy combining
if isinstance(other, PromptTemplate): if isinstance(other, PromptTemplate):
if self.template_format != "f-string": if self.template_format != other.template_format:
msg = "Adding prompt templates only supported for f-strings." msg = "Cannot add templates of different formats"
raise ValueError(msg)
if other.template_format != "f-string":
msg = "Adding prompt templates only supported for f-strings."
raise ValueError(msg) raise ValueError(msg)
input_variables = list( input_variables = list(
set(self.input_variables) | set(other.input_variables) set(self.input_variables) | set(other.input_variables)
@ -160,11 +157,14 @@ class PromptTemplate(StringPromptTemplate):
template=template, template=template,
input_variables=input_variables, input_variables=input_variables,
partial_variables=partial_variables, partial_variables=partial_variables,
template_format="f-string", template_format=self.template_format,
validate_template=validate_template, validate_template=validate_template,
) )
if isinstance(other, str): if isinstance(other, str):
prompt = PromptTemplate.from_template(other) prompt = PromptTemplate.from_template(
other,
template_format=self.template_format,
)
return self + prompt return self + prompt
msg = f"Unsupported operand type for +: {type(other)}" msg = f"Unsupported operand type for +: {type(other)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)

View File

@ -1,7 +1,7 @@
"""Test functionality related to prompts.""" """Test functionality related to prompts."""
import re import re
from typing import Any, Union from typing import Any, Literal, Union
from unittest import mock from unittest import mock
import pytest import pytest
@ -681,3 +681,54 @@ def test_prompt_with_template_variable_name_jinja2() -> None:
template = "This is a {{template}} test." template = "This is a {{template}} test."
prompt = PromptTemplate.from_template(template, template_format="jinja2") prompt = PromptTemplate.from_template(template, template_format="jinja2")
assert prompt.invoke({"template": "bar"}).to_string() == "This is a bar test." assert prompt.invoke({"template": "bar"}).to_string() == "This is a bar test."
def test_prompt_template_add_with_with_another_format() -> None:
with pytest.raises(ValueError, match=r"Cannot add templates"):
(
PromptTemplate.from_template("This is a {template}")
+ PromptTemplate.from_template("So {{this}} is", template_format="mustache")
)
@pytest.mark.parametrize(
("template_format", "prompt1", "prompt2"),
[
("f-string", "This is a {variable}", ". This is {another_variable}"),
pytest.param(
"jinja2",
"This is a {{variable}}",
". This is {{another_variable}}",
marks=[pytest.mark.requires("jinja2")],
),
("mustache", "This is a {{variable}}", ". This is {{another_variable}}"),
],
)
def test_prompt_template_add(
template_format: Literal["f-string", "mustache", "jinja2"],
prompt1: str,
prompt2: str,
) -> None:
first_prompt = PromptTemplate.from_template(
prompt1,
template_format=template_format,
)
second_prompt = PromptTemplate.from_template(
prompt2,
template_format=template_format,
)
concated_prompt = first_prompt + second_prompt
prompt_of_concated = PromptTemplate.from_template(
prompt1 + prompt2,
template_format=template_format,
)
assert concated_prompt.input_variables == prompt_of_concated.input_variables
assert concated_prompt.format(
variable="template",
another_variable="other_template",
) == prompt_of_concated.format(
variable="template",
another_variable="other_template",
)