core(patch):fix partial_variables not working with SystemMessagePromptTemplate (#20711)

- **Issue:**  close #17560
- @baskaryan, @eyurtsev
This commit is contained in:
Guangdong Liu 2024-06-04 07:22:42 +08:00 committed by GitHub
parent f2dd31b9e8
commit bc7e32f315
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 1 deletions

View File

@ -407,6 +407,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
cls: Type[_StringImageMessagePromptTemplateT], cls: Type[_StringImageMessagePromptTemplateT],
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]], template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string", template_format: str = "f-string",
*,
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> _StringImageMessagePromptTemplateT: ) -> _StringImageMessagePromptTemplateT:
"""Create a class from a string template. """Create a class from a string template.
@ -414,6 +416,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Args: Args:
template: a template. template: a template.
template_format: format of the template. template_format: format of the template.
partial_variables: A dictionary of variables that can be used too partially.
**kwargs: keyword arguments to pass to the constructor. **kwargs: keyword arguments to pass to the constructor.
Returns: Returns:
@ -421,10 +424,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
""" """
if isinstance(template, str): if isinstance(template, str):
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template( prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
template, template_format=template_format template,
template_format=template_format,
partial_variables=partial_variables,
) )
return cls(prompt=prompt, **kwargs) return cls(prompt=prompt, **kwargs)
elif isinstance(template, list): elif isinstance(template, list):
if (partial_variables is not None) and len(partial_variables) > 0:
raise ValueError(
"Partial variables are not supported for list of templates."
)
prompt = [] prompt = []
for tmpl in template: for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:

View File

@ -99,6 +99,79 @@ def test_create_chat_prompt_template_from_template_partial() -> None:
assert output_prompt.prompt == expected_prompt assert output_prompt.prompt == expected_prompt
def test_create_system_message_prompt_template_from_template_partial() -> None:
"""Create a system message prompt template with partials."""
graph_creator_content = """
Your instructions are:
{instructions}
History:
{history}
"""
json_prompt_instructions: dict = {}
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=graph_creator_content,
input_variables=["history"],
partial_variables={"instructions": json_prompt_instructions},
)
assert graph_analyst_template.format(history="history") == SystemMessage(
content="\n Your instructions are:\n "
" {}\n History:\n "
"history\n "
)
def test_create_system_message_prompt_list_template() -> None:
graph_creator_content1 = """
This is the prompt for the first test:
{variables}
"""
graph_creator_content2 = """
This is the prompt for the second test:
{variables}
"""
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=[graph_creator_content1, graph_creator_content2],
input_variables=["variables"],
)
assert graph_analyst_template.format(variables="foo") == SystemMessage(
content=[
{
"type": "text",
"text": "\n This is the prompt for the first test:\n foo\n ",
},
{
"type": "text",
"text": "\n This is the prompt for "
"the second test:\n foo\n ",
},
]
)
def test_create_system_message_prompt_list_template_partial_variables_not_null() -> (
None
):
graph_creator_content1 = """
This is the prompt for the first test:
{variables}
"""
graph_creator_content2 = """
This is the prompt for the second test:
{variables}
"""
try:
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=[graph_creator_content1, graph_creator_content2],
input_variables=["variables"],
partial_variables={"variables": "foo"},
)
graph_analyst_template.format(variables="foo")
except ValueError as e:
assert str(e) == "Partial variables are not supported for list of templates."
def test_message_prompt_template_from_template_file() -> None: def test_message_prompt_template_from_template_file() -> None:
expected = ChatMessagePromptTemplate( expected = ChatMessagePromptTemplate(
prompt=PromptTemplate( prompt=PromptTemplate(