core[patch]: dict chat prompt template support (#25674)

- Support passing dicts as templates to chat prompt template
- Support making *any* attribute on a message a runtime variable
- Significantly simpler than trying to update our existing prompt
template classes

```python
    template = ChatPromptTemplate(
        [
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "{text1}",
                        "cache_control": {"type": "ephemeral"},
                    },
                    {"type": "image_url", "image_url": {"path": "{local_image_path}"}},
                ],
                "name": "{name1}",
                "tool_calls": [
                    {
                        "name": "{tool_name1}",
                        "args": {"arg1": "{tool_arg1}"},
                        "id": "1",
                        "type": "tool_call",
                    }
                ],
            },
            {
                "role": "tool",
                "content": "{tool_content2}",
                "tool_call_id": "1",
                "name": "{tool_name1}",
            },
        ]
    )

```

will likely close #25514 if we like this idea and update to use this
logic

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Bagatur 2025-04-15 08:00:49 -07:00 committed by GitHub
parent 9cfe6bcacd
commit 7262de4217
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 476 additions and 103 deletions

View File

@ -25,7 +25,6 @@ from pydantic import (
from typing_extensions import Self, override
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
@ -39,6 +38,10 @@ from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.message import (
BaseMessagePromptTemplate,
_DictMessagePromptTemplate,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
PromptTemplateFormat,
@ -52,87 +55,6 @@ if TYPE_CHECKING:
from collections.abc import Sequence
class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable.
Returns: True.
"""
return True
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default namespace is ["langchain", "prompts", "chat"].
"""
return ["langchain", "prompts", "chat"]
@abstractmethod
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Async format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return self.format_messages(**kwargs)
@property
@abstractmethod
def input_variables(self) -> list[str]:
"""Input variables for this prompt template.
Returns:
List of input variables.
"""
def pretty_repr(
self,
html: bool = False, # noqa: FBT001,FBT002
) -> str:
"""Human-readable representation.
Args:
html: Whether to format as HTML. Defaults to False.
Returns:
Human-readable representation.
"""
raise NotImplementedError
def pretty_print(self) -> None:
"""Print a human-readable representation."""
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages.
@ -473,7 +395,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
StringPromptTemplate,
list[
Union[StringPromptTemplate, ImagePromptTemplate, _DictMessagePromptTemplate]
],
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
@ -484,7 +409,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod
def from_template(
cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template: Union[
str,
list[Union[str, _TextTemplateParam, _ImageTemplateParam, dict[str, Any]]],
],
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
@ -567,6 +495,19 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
msg = f"Invalid image template: {tmpl}"
raise ValueError(msg)
prompt.append(img_template_obj)
elif isinstance(tmpl, dict):
if template_format == "jinja2":
msg = (
"jinja2 is unsafe and is not supported for templates "
"expressed as dicts. Please use 'f-string' or 'mustache' "
"format."
)
raise ValueError(msg)
data_template_obj = _DictMessagePromptTemplate(
template=cast("dict[str, Any]", tmpl),
template_format=template_format,
)
prompt.append(data_template_obj)
else:
msg = f"Invalid template: {tmpl}"
raise ValueError(msg)
@ -644,11 +585,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
formatted: Union[str, ImageURL, dict[str, Any]] = prompt.format(
**inputs
)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, _DictMessagePromptTemplate):
formatted = prompt.format(**inputs)
content.append(formatted)
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
@ -671,11 +617,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
formatted: Union[str, ImageURL, dict[str, Any]] = await prompt.aformat(
**inputs
)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = await prompt.aformat(**inputs)
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, _DictMessagePromptTemplate):
formatted = prompt.format(**inputs)
content.append(formatted)
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
@ -811,7 +762,7 @@ MessageLikeRepresentation = Union[
Union[str, list[dict], list[object]],
],
str,
dict,
dict[str, Any],
]
@ -984,7 +935,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""
_messages = [
_convert_to_message(message, template_format) for message in messages
_convert_to_message_template(message, template_format)
for message in messages
]
# Automatically infer input variables from messages
@ -1071,7 +1023,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
ValueError: If input variables do not match.
"""
messages = values["messages"]
input_vars = set()
input_vars: set = set()
optional_variables = set()
input_types: dict[str, Any] = values.get("input_types", {})
for message in messages:
@ -1126,7 +1078,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
return cls.from_messages([message])
@classmethod
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
@deprecated("0.0.1", alternative="from_messages", pending=True)
def from_role_strings(
cls, string_messages: list[tuple[str, str]]
) -> ChatPromptTemplate:
@ -1146,7 +1098,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
)
@classmethod
@deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
@deprecated("0.0.1", alternative="from_messages", pending=True)
def from_strings(
cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
@ -1297,7 +1249,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
self.messages.append(_convert_to_message_template(message))
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
"""Extend the chat template with a sequence of messages.
@ -1305,7 +1257,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Args:
messages: sequence of message representations to append.
"""
self.messages.extend([_convert_to_message(message) for message in messages])
self.messages.extend(
[_convert_to_message_template(message) for message in messages]
)
@overload
def __getitem__(self, index: int) -> MessageLike: ...
@ -1425,7 +1379,7 @@ def _create_template_from_message_type(
return message
def _convert_to_message(
def _convert_to_message_template(
message: MessageLikeRepresentation,
template_format: PromptTemplateFormat = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
@ -1488,3 +1442,7 @@ def _convert_to_message(
raise NotImplementedError(msg)
return _message
# For backwards compat:
_convert_to_message = _convert_to_message_template

View File

@ -14,10 +14,8 @@ from typing_extensions import override
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
)
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.prompts.message import BaseMessagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,

View File

@ -0,0 +1,186 @@
"""Message prompt templates."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal
from langchain_core.load import Serializable
from langchain_core.messages import BaseMessage, convert_to_messages
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
get_template_variables,
)
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
class BaseMessagePromptTemplate(Serializable, ABC):
    """Base class for message prompt templates."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether or not the class is serializable.

        Returns: True.
        """
        return True

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        Default namespace is ["langchain", "prompts", "chat"].
        """
        # Kept as "chat" (not "message") so existing serialized payloads that
        # predate this module split keep deserializing to the same classes.
        return ["langchain", "prompts", "chat"]

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format messages from kwargs. Should return a list of BaseMessages.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """

    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format messages from kwargs.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """
        # Default async implementation simply delegates to the sync version;
        # subclasses with real async work should override.
        return self.format_messages(**kwargs)

    @property
    @abstractmethod
    def input_variables(self) -> list[str]:
        """Input variables for this prompt template.

        Returns:
            List of input variables.
        """

    def pretty_repr(
        self,
        html: bool = False,  # noqa: FBT001,FBT002
    ) -> str:
        """Human-readable representation.

        Args:
            html: Whether to format as HTML. Defaults to False.

        Returns:
            Human-readable representation.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    def pretty_print(self) -> None:
        """Print a human-readable representation."""
        print(self.pretty_repr(html=is_interactive_env()))  # noqa: T201

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates.

        Args:
            other: Another prompt template.

        Returns:
            Combined prompt template.
        """
        # Local import: langchain_core.prompts.chat imports from this module,
        # so a top-level import here would create a circular dependency.
        from langchain_core.prompts.chat import ChatPromptTemplate

        prompt = ChatPromptTemplate(messages=[self])
        return prompt + other
class _DictMessagePromptTemplate(BaseMessagePromptTemplate):
    """Template represented by a dict that recursively fills input vars in string vals.

    Special handling of image_url dicts to load local paths. These look like:
    ``{"type": "image_url", "image_url": {"path": "..."}}``

    NOTE(review): local-path image dicts of the above shape are *rejected*
    at format time with a ``ValueError`` (see ``_insert_input_variables``);
    they are no longer loaded, for security reasons.
    """

    # Message dict; any string value, at any nesting depth, may contain
    # template variables (e.g. {"role": "tool", "content": "{tool_content}"}).
    template: dict[str, Any]
    # Templating syntax for string values. jinja2 is deliberately excluded
    # because it is unsafe for dict-expressed templates.
    template_format: Literal["f-string", "mustache"]

    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Fill in template variables and convert the dict to a message.

        Args:
            **kwargs: Values for the template's input variables.

        Returns:
            Single-element list containing the formatted BaseMessage.
        """
        msg_dict = _insert_input_variables(self.template, kwargs, self.template_format)
        return convert_to_messages([msg_dict])

    @property
    def input_variables(self) -> list[str]:
        """Distinct template variables found anywhere in the template dict."""
        return _get_input_variables(self.template, self.template_format)

    @property
    def _prompt_type(self) -> str:
        """Identifier for this prompt type."""
        return "message-dict-prompt"

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain_core", "prompts", "message"]

    def format(
        self,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Format the prompt with the inputs, returning the filled-in dict."""
        return _insert_input_variables(self.template, kwargs, self.template_format)
def _get_input_variables(
    template: dict, template_format: Literal["f-string", "mustache"]
) -> list[str]:
    """Collect the distinct template variables referenced in a message dict.

    String values are scanned for variables, dict values are recursed into,
    and list/tuple values are scanned one level deep for string or dict items.

    Args:
        template: The message template dict.
        template_format: Either "f-string" or "mustache".

    Returns:
        De-duplicated list of variable names (order unspecified).
    """

    def _collect(value: Any) -> list[str]:
        # Variables contributed by a single value of the template dict.
        if isinstance(value, str):
            return get_template_variables(value, template_format)
        if isinstance(value, dict):
            return _get_input_variables(value, template_format)
        if isinstance(value, (list, tuple)):
            return [
                var
                for item in value
                if isinstance(item, (str, dict))
                for var in _collect(item)
            ]
        return []

    return list({var for value in template.values() for var in _collect(value)})
def _insert_input_variables(
    template: dict[str, Any],
    inputs: dict[str, Any],
    template_format: Literal["f-string", "mustache"],
) -> dict[str, Any]:
    """Recursively substitute input variables into a message template dict.

    String values are formatted with ``inputs`` using ``template_format``;
    dict values are processed recursively; list/tuple values are processed
    element-wise (preserving the container type). All other values are
    copied through unchanged.

    Args:
        template: The message template dict.
        inputs: Values for the template's input variables.
        template_format: Either "f-string" or "mustache".

    Returns:
        A new dict with all template variables replaced.

    Raises:
        ValueError: If the template contains an ``image_url`` dict with a
            local ``path`` key, which is rejected for security reasons.
    """
    formatted: dict[str, Any] = {}
    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
    for k, v in template.items():
        if isinstance(v, str):
            formatted[k] = formatter(v, **inputs)
        elif isinstance(v, dict):
            # No longer support loading local images.
            if k == "image_url" and "path" in v:
                msg = (
                    "Specifying image inputs via file path in environments with "
                    "user-input paths is a security vulnerability. Out of an abundance "
                    "of caution, the utility has been removed to prevent possible "
                    "misuse."
                )
                raise ValueError(msg)
            formatted[k] = _insert_input_variables(v, inputs, template_format)
        elif isinstance(v, (list, tuple)):
            formatted_v = []
            for x in v:
                if isinstance(x, str):
                    formatted_v.append(formatter(x, **inputs))
                elif isinstance(x, dict):
                    formatted_v.append(
                        _insert_input_variables(x, inputs, template_format)
                    )
                else:
                    # Fix: non-str/dict items (ints, bools, None, ...) were
                    # previously dropped from the list silently.
                    formatted_v.append(x)
            formatted[k] = type(v)(formatted_v)
        else:
            # Fix: values that are not str/dict/list/tuple (e.g. an int
            # message "id", a bool flag, or None) were previously omitted
            # from the formatted message dict entirely, silently losing data.
            formatted[k] = v
    return formatted

View File

@ -18,26 +18,29 @@ from langchain_core.messages import (
ChatMessage,
HumanMessage,
SystemMessage,
ToolMessage,
get_buffer_string,
)
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
_convert_to_message,
_convert_to_message_template,
)
from langchain_core.prompts.message import BaseMessagePromptTemplate
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.utils.pydantic import (
PYDANTIC_VERSION,
)
from tests.unit_tests.pydantic_utils import _normalize_schema
CUR_DIR = Path(__file__).parent.absolute().resolve()
@pytest.fixture
def messages() -> list[BaseMessagePromptTemplate]:
@ -521,7 +524,7 @@ def test_convert_to_message(
args: Any, expected: Union[BaseMessage, BaseMessagePromptTemplate]
) -> None:
"""Test convert to message."""
assert _convert_to_message(args) == expected
assert _convert_to_message_template(args) == expected
def test_chat_prompt_template_indexing() -> None:
@ -566,7 +569,7 @@ def test_convert_to_message_is_strict() -> None:
# meow does not correspond to a valid message type.
# this test is here to ensure that functionality to interpret `meow`
# as a role is NOT added.
_convert_to_message(("meow", "question"))
_convert_to_message_template(("meow", "question"))
def test_chat_message_partial() -> None:
@ -955,7 +958,7 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) ->
assert load(dumpd(prompt)) == prompt
async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
"""Test chat prompt template ser/des."""
template = ChatPromptTemplate(
[
@ -1006,6 +1009,89 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
assert load(dumpd(template)) == template
@pytest.mark.xfail(
    reason=(
        "In a breaking release, we can update `_convert_to_message_template` to use "
        "_DictMessagePromptTemplate for all `dict` inputs, allowing for templatization "
        "of message attributes outside content blocks. That would enable the below "
        "test to pass."
    )
)
def test_chat_tmpl_dict_msg() -> None:
    """Dict messages should templatize attributes beyond content blocks.

    Currently expected to fail: `name` and `tool_calls` on the assistant
    message (and `name` on the tool message) are not yet treated as runtime
    variables when a dict message goes through ChatPromptTemplate.
    """
    template = ChatPromptTemplate(
        [
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "{text1}",
                        "cache_control": {"type": "ephemeral"},
                    },
                ],
                # Non-content attributes below also carry template variables.
                "name": "{name1}",
                "tool_calls": [
                    {
                        "name": "{tool_name1}",
                        "args": {"arg1": "{tool_arg1}"},
                        "id": "1",
                        "type": "tool_call",
                    }
                ],
            },
            {
                "role": "tool",
                "content": "{tool_content2}",
                "tool_call_id": "1",
                "name": "{tool_name1}",
            },
        ]
    )
    expected = [
        AIMessage(
            [
                {
                    "type": "text",
                    "text": "important message",
                    "cache_control": {"type": "ephemeral"},
                },
            ],
            name="foo",
            tool_calls=[
                {
                    "name": "do_stuff",
                    "args": {"arg1": "important arg1"},
                    "id": "1",
                    "type": "tool_call",
                }
            ],
        ),
        ToolMessage("foo", name="do_stuff", tool_call_id="1"),
    ]
    actual = template.invoke(
        {
            "text1": "important message",
            "name1": "foo",
            "tool_arg1": "important arg1",
            "tool_name1": "do_stuff",
            "tool_content2": "foo",
        }
    ).to_messages()
    assert actual == expected

    # Partial application of one variable must produce the same result when
    # the remaining variables are supplied at invoke time.
    partial_ = template.partial(text1="important message")
    actual = partial_.invoke(
        {
            "name1": "foo",
            "tool_arg1": "important arg1",
            "tool_name1": "do_stuff",
            "tool_content2": "foo",
        }
    ).to_messages()
    assert actual == expected
def test_chat_prompt_template_variable_names() -> None:
"""This test was written for an edge case that triggers a warning from Pydantic.
@ -1049,3 +1135,87 @@ def test_chat_prompt_template_variable_names() -> None:
"title": "PromptInput",
"type": "object",
}
def test_data_prompt_template_deserializable() -> None:
    """A chat template with a data content block survives a dumpd/load round trip."""
    image_block = {"type": "image", "source_type": "url", "url": "{url}"}
    template = ChatPromptTemplate.from_messages([("system", [image_block])])
    serialized = dumpd(template)
    # Deserialization must not raise.
    load(serialized)
@pytest.mark.requires("jinja2")
@pytest.mark.parametrize(
    ("template_format", "cache_control_placeholder", "source_data_placeholder"),
    [
        ("f-string", "{cache_type}", "{source_data}"),
        ("mustache", "{{cache_type}}", "{{source_data}}"),
    ],
)
def test_chat_prompt_template_data_prompt_from_message(
    template_format: PromptTemplateFormat,
    cache_control_placeholder: str,
    source_data_placeholder: str,
) -> None:
    """Dict content blocks are templatized for both f-string and mustache."""
    prompt: dict = {
        "type": "image",
        "source_type": "base64",
        "data": f"{source_data_placeholder}",
    }
    template = ChatPromptTemplate.from_messages(
        [("human", [prompt])], template_format=template_format
    )
    assert template.format_messages(source_data="base64data") == [
        HumanMessage(
            content=[
                {
                    "type": "image",
                    "source_type": "base64",
                    "data": "base64data",
                }
            ]
        )
    ]

    # metadata
    # Variables nested inside dict values (metadata.cache_control.type here)
    # must be filled in as well.
    prompt["metadata"] = {"cache_control": {"type": f"{cache_control_placeholder}"}}
    template = ChatPromptTemplate.from_messages(
        [("human", [prompt])], template_format=template_format
    )
    assert template.format_messages(
        cache_type="ephemeral", source_data="base64data"
    ) == [
        HumanMessage(
            content=[
                {
                    "type": "image",
                    "source_type": "base64",
                    "data": "base64data",
                    "metadata": {"cache_control": {"type": "ephemeral"}},
                }
            ]
        )
    ]
def test_dict_message_prompt_template_errors_on_jinja2() -> None:
    """Dict-expressed content blocks must reject the unsafe jinja2 format."""
    image_block = {
        "type": "image",
        "source_type": "base64",
        "data": "{source_data}",
    }
    with pytest.raises(ValueError, match="jinja2"):
        ChatPromptTemplate.from_messages(
            [("human", [image_block])], template_format="jinja2"
        )

View File

@ -0,0 +1,61 @@
from pathlib import Path
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langchain_core.prompts.message import _DictMessagePromptTemplate
CUR_DIR = Path(__file__).parent.absolute().resolve()
def test__dict_message_prompt_template_fstring() -> None:
    """_DictMessagePromptTemplate fills f-string variables at any depth.

    Covers an assistant message (variables in content blocks, `name`, and
    `tool_calls`) and a tool message (variables in `content` and `name`).
    """
    # Assistant message: variables appear inside content blocks and in
    # message attributes outside the content.
    template = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "{text1}", "cache_control": {"type": "ephemeral"}},
        ],
        "name": "{name1}",
        "tool_calls": [
            {
                "name": "{tool_name1}",
                "args": {"arg1": "{tool_arg1}"},
                "id": "1",
                "type": "tool_call",
            }
        ],
    }
    prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
    expected: BaseMessage = AIMessage(
        [
            {
                "type": "text",
                "text": "important message",
                "cache_control": {"type": "ephemeral"},
            },
        ],
        name="foo",
        tool_calls=[
            {
                "name": "do_stuff",
                "args": {"arg1": "important arg1"},
                "id": "1",
                "type": "tool_call",
            }
        ],
    )
    actual = prompt.format_messages(
        text1="important message",
        name1="foo",
        tool_arg1="important arg1",
        tool_name1="do_stuff",
    )[0]
    assert actual == expected

    # Tool message: string content and the `name` attribute are templatized.
    template = {
        "role": "tool",
        "content": "{content1}",
        "tool_call_id": "1",
        "name": "{name1}",
    }
    prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
    expected = ToolMessage("foo", name="bar", tool_call_id="1")
    actual = prompt.format_messages(content1="foo", name1="bar")[0]
    assert actual == expected