core(patch):fix partial_variables not working with SystemMessagePromptTemplate (#20711)

- **Issue:**  close #17560
- @baskaryan, @eyurtsev
This commit is contained in:
Guangdong Liu
2024-06-04 07:22:42 +08:00
committed by GitHub
parent f2dd31b9e8
commit bc7e32f315
2 changed files with 83 additions and 1 deletions

View File

@@ -99,6 +99,79 @@ def test_create_chat_prompt_template_from_template_partial() -> None:
assert output_prompt.prompt == expected_prompt
def test_create_system_message_prompt_template_from_template_partial() -> None:
"""Create a system message prompt template with partials."""
graph_creator_content = """
Your instructions are:
{instructions}
History:
{history}
"""
json_prompt_instructions: dict = {}
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=graph_creator_content,
input_variables=["history"],
partial_variables={"instructions": json_prompt_instructions},
)
assert graph_analyst_template.format(history="history") == SystemMessage(
content="\n Your instructions are:\n "
" {}\n History:\n "
"history\n "
)
def test_create_system_message_prompt_list_template() -> None:
graph_creator_content1 = """
This is the prompt for the first test:
{variables}
"""
graph_creator_content2 = """
This is the prompt for the second test:
{variables}
"""
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=[graph_creator_content1, graph_creator_content2],
input_variables=["variables"],
)
assert graph_analyst_template.format(variables="foo") == SystemMessage(
content=[
{
"type": "text",
"text": "\n This is the prompt for the first test:\n foo\n ",
},
{
"type": "text",
"text": "\n This is the prompt for "
"the second test:\n foo\n ",
},
]
)
def test_create_system_message_prompt_list_template_partial_variables_not_null() -> (
None
):
graph_creator_content1 = """
This is the prompt for the first test:
{variables}
"""
graph_creator_content2 = """
This is the prompt for the second test:
{variables}
"""
try:
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=[graph_creator_content1, graph_creator_content2],
input_variables=["variables"],
partial_variables={"variables": "foo"},
)
graph_analyst_template.format(variables="foo")
except ValueError as e:
assert str(e) == "Partial variables are not supported for list of templates."
def test_message_prompt_template_from_template_file() -> None:
expected = ChatMessagePromptTemplate(
prompt=PromptTemplate(