mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 18:53:02 +00:00
Merge ec15055241
into 0e287763cd
This commit is contained in:
commit
2fb9bb4920
@ -138,11 +138,8 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
"""Override the + operator to allow for combining prompt templates."""
|
"""Override the + operator to allow for combining prompt templates."""
|
||||||
# Allow for easy combining
|
# Allow for easy combining
|
||||||
if isinstance(other, PromptTemplate):
|
if isinstance(other, PromptTemplate):
|
||||||
if self.template_format != "f-string":
|
if self.template_format != other.template_format:
|
||||||
msg = "Adding prompt templates only supported for f-strings."
|
msg = "Cannot add templates of different formats"
|
||||||
raise ValueError(msg)
|
|
||||||
if other.template_format != "f-string":
|
|
||||||
msg = "Adding prompt templates only supported for f-strings."
|
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
input_variables = list(
|
input_variables = list(
|
||||||
set(self.input_variables) | set(other.input_variables)
|
set(self.input_variables) | set(other.input_variables)
|
||||||
@ -160,11 +157,14 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
template=template,
|
template=template,
|
||||||
input_variables=input_variables,
|
input_variables=input_variables,
|
||||||
partial_variables=partial_variables,
|
partial_variables=partial_variables,
|
||||||
template_format="f-string",
|
template_format=self.template_format,
|
||||||
validate_template=validate_template,
|
validate_template=validate_template,
|
||||||
)
|
)
|
||||||
if isinstance(other, str):
|
if isinstance(other, str):
|
||||||
prompt = PromptTemplate.from_template(other)
|
prompt = PromptTemplate.from_template(
|
||||||
|
other,
|
||||||
|
template_format=self.template_format,
|
||||||
|
)
|
||||||
return self + prompt
|
return self + prompt
|
||||||
msg = f"Unsupported operand type for +: {type(other)}"
|
msg = f"Unsupported operand type for +: {type(other)}"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Test functionality related to prompts."""
|
"""Test functionality related to prompts."""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Union
|
from typing import Any, Literal, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -681,3 +681,54 @@ def test_prompt_with_template_variable_name_jinja2() -> None:
|
|||||||
template = "This is a {{template}} test."
|
template = "This is a {{template}} test."
|
||||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||||
assert prompt.invoke({"template": "bar"}).to_string() == "This is a bar test."
|
assert prompt.invoke({"template": "bar"}).to_string() == "This is a bar test."
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_template_add_with_with_another_format() -> None:
|
||||||
|
with pytest.raises(ValueError, match=r"Cannot add templates"):
|
||||||
|
(
|
||||||
|
PromptTemplate.from_template("This is a {template}")
|
||||||
|
+ PromptTemplate.from_template("So {{this}} is", template_format="mustache")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("template_format", "prompt1", "prompt2"),
|
||||||
|
[
|
||||||
|
("f-string", "This is a {variable}", ". This is {another_variable}"),
|
||||||
|
pytest.param(
|
||||||
|
"jinja2",
|
||||||
|
"This is a {{variable}}",
|
||||||
|
". This is {{another_variable}}",
|
||||||
|
marks=[pytest.mark.requires("jinja2")],
|
||||||
|
),
|
||||||
|
("mustache", "This is a {{variable}}", ". This is {{another_variable}}"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_prompt_template_add(
|
||||||
|
template_format: Literal["f-string", "mustache", "jinja2"],
|
||||||
|
prompt1: str,
|
||||||
|
prompt2: str,
|
||||||
|
) -> None:
|
||||||
|
first_prompt = PromptTemplate.from_template(
|
||||||
|
prompt1,
|
||||||
|
template_format=template_format,
|
||||||
|
)
|
||||||
|
second_prompt = PromptTemplate.from_template(
|
||||||
|
prompt2,
|
||||||
|
template_format=template_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
concated_prompt = first_prompt + second_prompt
|
||||||
|
prompt_of_concated = PromptTemplate.from_template(
|
||||||
|
prompt1 + prompt2,
|
||||||
|
template_format=template_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert concated_prompt.input_variables == prompt_of_concated.input_variables
|
||||||
|
assert concated_prompt.format(
|
||||||
|
variable="template",
|
||||||
|
another_variable="other_template",
|
||||||
|
) == prompt_of_concated.format(
|
||||||
|
variable="template",
|
||||||
|
another_variable="other_template",
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user