Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-08 12:31:49 +00:00
core[patch]: dict chat prompt template support (#25674)
- Support passing dicts as templates to chat prompt template
- Support making *any* attribute on a message a runtime variable
- Significantly simpler than trying to update our existing prompt template classes

```python
template = ChatPromptTemplate(
    [
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "{text1}",
                    "cache_control": {"type": "ephemeral"},
                },
                {"type": "image_url", "image_url": {"path": "{local_image_path}"}},
            ],
            "name": "{name1}",
            "tool_calls": [
                {
                    "name": "{tool_name1}",
                    "args": {"arg1": "{tool_arg1}"},
                    "id": "1",
                    "type": "tool_call",
                }
            ],
        },
        {
            "role": "tool",
            "content": "{tool_content2}",
            "tool_call_id": "1",
            "name": "{tool_name1}",
        },
    ]
)
```

Will likely close #25514 if we like this idea and update to use this logic.

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
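For context, here is a minimal sketch of how the dict-based content blocks behave, modeled on the unit tests added in this PR (the block keys and variable names below are illustrative):

```python
from langchain_core.prompts import ChatPromptTemplate

# A content block expressed as a plain dict; any string value (including
# nested ones) can contain a template variable.
block = {
    "type": "image",
    "source_type": "base64",
    "data": "{source_data}",
    "metadata": {"cache_control": {"type": "{cache_type}"}},
}

template = ChatPromptTemplate.from_messages(
    [("human", [block])], template_format="f-string"
)

# Variables anywhere in the dict are filled in at format time.
messages = template.format_messages(source_data="base64data", cache_type="ephemeral")
print(messages[0].content)
# [{'type': 'image', 'source_type': 'base64', 'data': 'base64data',
#   'metadata': {'cache_control': {'type': 'ephemeral'}}}]
```

Note that jinja2 is deliberately not supported for templates expressed as dicts (see the guard added in the diff below); only 'f-string' and 'mustache' work here.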
This commit is contained in:
parent 9cfe6bcacd
commit 7262de4217
@@ -25,7 +25,6 @@ from pydantic import (
 from typing_extensions import Self, override
 
 from langchain_core._api import deprecated
-from langchain_core.load import Serializable
 from langchain_core.messages import (
     AIMessage,
     AnyMessage,
@@ -39,6 +38,10 @@ from langchain_core.messages.base import get_msg_title_repr
 from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.prompts.image import ImagePromptTemplate
+from langchain_core.prompts.message import (
+    BaseMessagePromptTemplate,
+    _DictMessagePromptTemplate,
+)
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.string import (
     PromptTemplateFormat,
@@ -52,87 +55,6 @@ if TYPE_CHECKING:
     from collections.abc import Sequence
 
 
-class BaseMessagePromptTemplate(Serializable, ABC):
-    """Base class for message prompt templates."""
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """Return whether or not the class is serializable.
-
-        Returns: True.
-        """
-        return True
-
-    @classmethod
-    def get_lc_namespace(cls) -> list[str]:
-        """Get the namespace of the langchain object.
-
-        Default namespace is ["langchain", "prompts", "chat"].
-        """
-        return ["langchain", "prompts", "chat"]
-
-    @abstractmethod
-    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
-        """Format messages from kwargs. Should return a list of BaseMessages.
-
-        Args:
-            **kwargs: Keyword arguments to use for formatting.
-
-        Returns:
-            List of BaseMessages.
-        """
-
-    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
-        """Async format messages from kwargs.
-
-        Args:
-            **kwargs: Keyword arguments to use for formatting.
-
-        Returns:
-            List of BaseMessages.
-        """
-        return self.format_messages(**kwargs)
-
-    @property
-    @abstractmethod
-    def input_variables(self) -> list[str]:
-        """Input variables for this prompt template.
-
-        Returns:
-            List of input variables.
-        """
-
-    def pretty_repr(
-        self,
-        html: bool = False,  # noqa: FBT001,FBT002
-    ) -> str:
-        """Human-readable representation.
-
-        Args:
-            html: Whether to format as HTML. Defaults to False.
-
-        Returns:
-            Human-readable representation.
-        """
-        raise NotImplementedError
-
-    def pretty_print(self) -> None:
-        """Print a human-readable representation."""
-        print(self.pretty_repr(html=is_interactive_env()))  # noqa: T201
-
-    def __add__(self, other: Any) -> ChatPromptTemplate:
-        """Combine two prompt templates.
-
-        Args:
-            other: Another prompt template.
-
-        Returns:
-            Combined prompt template.
-        """
-        prompt = ChatPromptTemplate(messages=[self])
-        return prompt + other
-
-
 class MessagesPlaceholder(BaseMessagePromptTemplate):
     """Prompt template that assumes variable is already list of messages.
 
@@ -473,7 +395,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
     """Human message prompt template. This is a message sent from the user."""
 
     prompt: Union[
-        StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
+        StringPromptTemplate,
+        list[
+            Union[StringPromptTemplate, ImagePromptTemplate, _DictMessagePromptTemplate]
+        ],
     ]
     """Prompt template."""
     additional_kwargs: dict = Field(default_factory=dict)
@@ -484,7 +409,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
     @classmethod
     def from_template(
         cls: type[Self],
-        template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
+        template: Union[
+            str,
+            list[Union[str, _TextTemplateParam, _ImageTemplateParam, dict[str, Any]]],
+        ],
         template_format: PromptTemplateFormat = "f-string",
         *,
         partial_variables: Optional[dict[str, Any]] = None,
@@ -567,6 +495,19 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
                         msg = f"Invalid image template: {tmpl}"
                         raise ValueError(msg)
                     prompt.append(img_template_obj)
+                elif isinstance(tmpl, dict):
+                    if template_format == "jinja2":
+                        msg = (
+                            "jinja2 is unsafe and is not supported for templates "
+                            "expressed as dicts. Please use 'f-string' or 'mustache' "
+                            "format."
+                        )
+                        raise ValueError(msg)
+                    data_template_obj = _DictMessagePromptTemplate(
+                        template=cast("dict[str, Any]", tmpl),
+                        template_format=template_format,
+                    )
+                    prompt.append(data_template_obj)
                 else:
                     msg = f"Invalid template: {tmpl}"
                     raise ValueError(msg)
@@ -644,11 +585,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
         for prompt in self.prompt:
             inputs = {var: kwargs[var] for var in prompt.input_variables}
             if isinstance(prompt, StringPromptTemplate):
-                formatted: Union[str, ImageURL] = prompt.format(**inputs)
+                formatted: Union[str, ImageURL, dict[str, Any]] = prompt.format(
+                    **inputs
+                )
                 content.append({"type": "text", "text": formatted})
             elif isinstance(prompt, ImagePromptTemplate):
                 formatted = prompt.format(**inputs)
                 content.append({"type": "image_url", "image_url": formatted})
+            elif isinstance(prompt, _DictMessagePromptTemplate):
+                formatted = prompt.format(**inputs)
+                content.append(formatted)
         return self._msg_class(
             content=content, additional_kwargs=self.additional_kwargs
         )
@@ -671,11 +617,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
         for prompt in self.prompt:
             inputs = {var: kwargs[var] for var in prompt.input_variables}
             if isinstance(prompt, StringPromptTemplate):
-                formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
+                formatted: Union[str, ImageURL, dict[str, Any]] = await prompt.aformat(
+                    **inputs
+                )
                 content.append({"type": "text", "text": formatted})
             elif isinstance(prompt, ImagePromptTemplate):
                 formatted = await prompt.aformat(**inputs)
                 content.append({"type": "image_url", "image_url": formatted})
+            elif isinstance(prompt, _DictMessagePromptTemplate):
+                formatted = prompt.format(**inputs)
+                content.append(formatted)
         return self._msg_class(
             content=content, additional_kwargs=self.additional_kwargs
         )
@@ -811,7 +762,7 @@ MessageLikeRepresentation = Union[
         Union[str, list[dict], list[object]],
     ],
     str,
-    dict,
+    dict[str, Any],
 ]
 
 
@@ -984,7 +935,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
 
         """
         _messages = [
-            _convert_to_message(message, template_format) for message in messages
+            _convert_to_message_template(message, template_format)
+            for message in messages
        ]
 
         # Automatically infer input variables from messages
@@ -1071,7 +1023,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
             ValueError: If input variables do not match.
         """
         messages = values["messages"]
-        input_vars = set()
+        input_vars: set = set()
         optional_variables = set()
         input_types: dict[str, Any] = values.get("input_types", {})
         for message in messages:
@@ -1126,7 +1078,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         return cls.from_messages([message])
 
     @classmethod
-    @deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
+    @deprecated("0.0.1", alternative="from_messages", pending=True)
     def from_role_strings(
         cls, string_messages: list[tuple[str, str]]
     ) -> ChatPromptTemplate:
@@ -1146,7 +1098,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         )
 
     @classmethod
-    @deprecated("0.0.1", alternative="from_messages classmethod", pending=True)
+    @deprecated("0.0.1", alternative="from_messages", pending=True)
     def from_strings(
         cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]]
     ) -> ChatPromptTemplate:
@@ -1297,7 +1249,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         Args:
             message: representation of a message to append.
         """
-        self.messages.append(_convert_to_message(message))
+        self.messages.append(_convert_to_message_template(message))
 
     def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
         """Extend the chat template with a sequence of messages.
@@ -1305,7 +1257,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         Args:
             messages: sequence of message representations to append.
         """
-        self.messages.extend([_convert_to_message(message) for message in messages])
+        self.messages.extend(
+            [_convert_to_message_template(message) for message in messages]
+        )
 
     @overload
     def __getitem__(self, index: int) -> MessageLike: ...
@@ -1425,7 +1379,7 @@ def _create_template_from_message_type(
     return message
 
 
-def _convert_to_message(
+def _convert_to_message_template(
     message: MessageLikeRepresentation,
     template_format: PromptTemplateFormat = "f-string",
 ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
@@ -1488,3 +1442,7 @@ def _convert_to_message(
         raise NotImplementedError(msg)
 
     return _message
+
+
+# For backwards compat:
+_convert_to_message = _convert_to_message_template
@@ -14,10 +14,8 @@ from typing_extensions import override
 
 from langchain_core.example_selectors import BaseExampleSelector
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_core.prompts.chat import (
-    BaseChatPromptTemplate,
-    BaseMessagePromptTemplate,
-)
+from langchain_core.prompts.chat import BaseChatPromptTemplate
+from langchain_core.prompts.message import BaseMessagePromptTemplate
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.prompts.string import (
     DEFAULT_FORMATTER_MAPPING,
libs/core/langchain_core/prompts/message.py (new file, 186 lines)
@@ -0,0 +1,186 @@
+"""Message prompt templates."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Literal
+
+from langchain_core.load import Serializable
+from langchain_core.messages import BaseMessage, convert_to_messages
+from langchain_core.prompts.string import (
+    DEFAULT_FORMATTER_MAPPING,
+    get_template_variables,
+)
+from langchain_core.utils.interactive_env import is_interactive_env
+
+if TYPE_CHECKING:
+    from langchain_core.prompts.chat import ChatPromptTemplate
+
+
+class BaseMessagePromptTemplate(Serializable, ABC):
+    """Base class for message prompt templates."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether or not the class is serializable.
+
+        Returns: True.
+        """
+        return True
+
+    @classmethod
+    def get_lc_namespace(cls) -> list[str]:
+        """Get the namespace of the langchain object.
+
+        Default namespace is ["langchain", "prompts", "chat"].
+        """
+        return ["langchain", "prompts", "chat"]
+
+    @abstractmethod
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
+        """Format messages from kwargs. Should return a list of BaseMessages.
+
+        Args:
+            **kwargs: Keyword arguments to use for formatting.
+
+        Returns:
+            List of BaseMessages.
+        """
+
+    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
+        """Async format messages from kwargs.
+
+        Args:
+            **kwargs: Keyword arguments to use for formatting.
+
+        Returns:
+            List of BaseMessages.
+        """
+        return self.format_messages(**kwargs)
+
+    @property
+    @abstractmethod
+    def input_variables(self) -> list[str]:
+        """Input variables for this prompt template.
+
+        Returns:
+            List of input variables.
+        """
+
+    def pretty_repr(
+        self,
+        html: bool = False,  # noqa: FBT001,FBT002
+    ) -> str:
+        """Human-readable representation.
+
+        Args:
+            html: Whether to format as HTML. Defaults to False.
+
+        Returns:
+            Human-readable representation.
+        """
+        raise NotImplementedError
+
+    def pretty_print(self) -> None:
+        """Print a human-readable representation."""
+        print(self.pretty_repr(html=is_interactive_env()))  # noqa: T201
+
+    def __add__(self, other: Any) -> ChatPromptTemplate:
+        """Combine two prompt templates.
+
+        Args:
+            other: Another prompt template.
+
+        Returns:
+            Combined prompt template.
+        """
+        from langchain_core.prompts.chat import ChatPromptTemplate
+
+        prompt = ChatPromptTemplate(messages=[self])
+        return prompt + other
+
+
+class _DictMessagePromptTemplate(BaseMessagePromptTemplate):
+    """Template represented by a dict that recursively fills input vars in string vals.
+
+    Special handling of image_url dicts to load local paths. These look like:
+    ``{"type": "image_url", "image_url": {"path": "..."}}``
+    """
+
+    template: dict[str, Any]
+    template_format: Literal["f-string", "mustache"]
+
+    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
+        msg_dict = _insert_input_variables(self.template, kwargs, self.template_format)
+        return convert_to_messages([msg_dict])
+
+    @property
+    def input_variables(self) -> list[str]:
+        return _get_input_variables(self.template, self.template_format)
+
+    @property
+    def _prompt_type(self) -> str:
+        return "message-dict-prompt"
+
+    @classmethod
+    def get_lc_namespace(cls) -> list[str]:
+        return ["langchain_core", "prompts", "message"]
+
+    def format(
+        self,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Format the prompt with the inputs."""
+        return _insert_input_variables(self.template, kwargs, self.template_format)
+
+
+def _get_input_variables(
+    template: dict, template_format: Literal["f-string", "mustache"]
+) -> list[str]:
+    input_variables = []
+    for v in template.values():
+        if isinstance(v, str):
+            input_variables += get_template_variables(v, template_format)
+        elif isinstance(v, dict):
+            input_variables += _get_input_variables(v, template_format)
+        elif isinstance(v, (list, tuple)):
+            for x in v:
+                if isinstance(x, str):
+                    input_variables += get_template_variables(x, template_format)
+                elif isinstance(x, dict):
+                    input_variables += _get_input_variables(x, template_format)
+    return list(set(input_variables))
+
+
+def _insert_input_variables(
+    template: dict[str, Any],
+    inputs: dict[str, Any],
+    template_format: Literal["f-string", "mustache"],
+) -> dict[str, Any]:
+    formatted = {}
+    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
+    for k, v in template.items():
+        if isinstance(v, str):
+            formatted[k] = formatter(v, **inputs)
+        elif isinstance(v, dict):
+            # No longer support loading local images.
+            if k == "image_url" and "path" in v:
+                msg = (
+                    "Specifying image inputs via file path in environments with "
+                    "user-input paths is a security vulnerability. Out of an abundance "
+                    "of caution, the utility has been removed to prevent possible "
+                    "misuse."
+                )
+                raise ValueError(msg)
+            formatted[k] = _insert_input_variables(v, inputs, template_format)
+        elif isinstance(v, (list, tuple)):
+            formatted_v = []
+            for x in v:
+                if isinstance(x, str):
+                    formatted_v.append(formatter(x, **inputs))
+                elif isinstance(x, dict):
+                    formatted_v.append(
+                        _insert_input_variables(x, inputs, template_format)
+                    )
+            formatted[k] = type(v)(formatted_v)
+    return formatted
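As a quick illustration of the recursive fill performed by `_insert_input_variables` above, here is a sketch that exercises `_DictMessagePromptTemplate` directly (a private class added by this PR, so its interface may change; the values below are illustrative):

```python
from langchain_core.prompts.message import _DictMessagePromptTemplate

tmpl = _DictMessagePromptTemplate(
    template={
        "role": "tool",
        "content": "{content1}",
        "tool_call_id": "1",
        "name": "{name1}",
    },
    template_format="f-string",
)

# Variables are collected recursively from every string value in the dict.
print(sorted(tmpl.input_variables))  # ['content1', 'name1']

# format() returns the same dict with the variables filled in.
print(tmpl.format(content1="42", name1="calculator"))
# {'role': 'tool', 'content': '42', 'tool_call_id': '1', 'name': 'calculator'}
```

`format_messages` runs the same fill and then passes the result through `convert_to_messages`, so the dict above would come back as a `ToolMessage`.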
@ -18,26 +18,29 @@ from langchain_core.messages import (
|
|||||||
ChatMessage,
|
ChatMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
ToolMessage,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
from langchain_core.prompt_values import ChatPromptValue
|
from langchain_core.prompt_values import ChatPromptValue
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
AIMessagePromptTemplate,
|
AIMessagePromptTemplate,
|
||||||
BaseMessagePromptTemplate,
|
|
||||||
ChatMessagePromptTemplate,
|
ChatMessagePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
MessagesPlaceholder,
|
MessagesPlaceholder,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
_convert_to_message,
|
_convert_to_message_template,
|
||||||
)
|
)
|
||||||
|
from langchain_core.prompts.message import BaseMessagePromptTemplate
|
||||||
from langchain_core.prompts.string import PromptTemplateFormat
|
from langchain_core.prompts.string import PromptTemplateFormat
|
||||||
from langchain_core.utils.pydantic import (
|
from langchain_core.utils.pydantic import (
|
||||||
PYDANTIC_VERSION,
|
PYDANTIC_VERSION,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||||
|
|
||||||
|
CUR_DIR = Path(__file__).parent.absolute().resolve()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def messages() -> list[BaseMessagePromptTemplate]:
|
def messages() -> list[BaseMessagePromptTemplate]:
|
||||||
@ -521,7 +524,7 @@ def test_convert_to_message(
|
|||||||
args: Any, expected: Union[BaseMessage, BaseMessagePromptTemplate]
|
args: Any, expected: Union[BaseMessage, BaseMessagePromptTemplate]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test convert to message."""
|
"""Test convert to message."""
|
||||||
assert _convert_to_message(args) == expected
|
assert _convert_to_message_template(args) == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_indexing() -> None:
|
def test_chat_prompt_template_indexing() -> None:
|
||||||
@ -566,7 +569,7 @@ def test_convert_to_message_is_strict() -> None:
|
|||||||
# meow does not correspond to a valid message type.
|
# meow does not correspond to a valid message type.
|
||||||
# this test is here to ensure that functionality to interpret `meow`
|
# this test is here to ensure that functionality to interpret `meow`
|
||||||
# as a role is NOT added.
|
# as a role is NOT added.
|
||||||
_convert_to_message(("meow", "question"))
|
_convert_to_message_template(("meow", "question"))
|
||||||
|
|
||||||
|
|
||||||
def test_chat_message_partial() -> None:
|
def test_chat_message_partial() -> None:
|
||||||
@ -955,7 +958,7 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) ->
|
|||||||
assert load(dumpd(prompt)) == prompt
|
assert load(dumpd(prompt)) == prompt
|
||||||
|
|
||||||
|
|
||||||
async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
|
def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
|
||||||
"""Test chat prompt template ser/des."""
|
"""Test chat prompt template ser/des."""
|
||||||
template = ChatPromptTemplate(
|
template = ChatPromptTemplate(
|
||||||
[
|
[
|
||||||
@ -1006,6 +1009,89 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
|
|||||||
assert load(dumpd(template)) == template
|
assert load(dumpd(template)) == template
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(
|
||||||
|
reason=(
|
||||||
|
"In a breaking release, we can update `_convert_to_message_template` to use "
|
||||||
|
"_DictMessagePromptTemplate for all `dict` inputs, allowing for templatization "
|
||||||
|
"of message attributes outside content blocks. That would enable the below "
|
||||||
|
"test to pass."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_chat_tmpl_dict_msg() -> None:
|
||||||
|
template = ChatPromptTemplate(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "{text1}",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"name": "{name1}",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"name": "{tool_name1}",
|
||||||
|
"args": {"arg1": "{tool_arg1}"},
|
||||||
|
"id": "1",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": "{tool_content2}",
|
||||||
|
"tool_call_id": "1",
|
||||||
|
"name": "{tool_name1}",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
expected = [
|
||||||
|
AIMessage(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "important message",
|
||||||
|
"cache_control": {"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
name="foo",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "do_stuff",
|
||||||
|
"args": {"arg1": "important arg1"},
|
||||||
|
"id": "1",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ToolMessage("foo", name="do_stuff", tool_call_id="1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
actual = template.invoke(
|
||||||
|
{
|
||||||
|
"text1": "important message",
|
||||||
|
"name1": "foo",
|
||||||
|
"tool_arg1": "important arg1",
|
||||||
|
"tool_name1": "do_stuff",
|
||||||
|
"tool_content2": "foo",
|
||||||
|
}
|
||||||
|
).to_messages()
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
partial_ = template.partial(text1="important message")
|
||||||
|
actual = partial_.invoke(
|
||||||
|
{
|
||||||
|
"name1": "foo",
|
||||||
|
"tool_arg1": "important arg1",
|
||||||
|
"tool_name1": "do_stuff",
|
||||||
|
"tool_content2": "foo",
|
||||||
|
}
|
||||||
|
).to_messages()
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_template_variable_names() -> None:
|
def test_chat_prompt_template_variable_names() -> None:
|
||||||
"""This test was written for an edge case that triggers a warning from Pydantic.
|
"""This test was written for an edge case that triggers a warning from Pydantic.
|
||||||
|
|
||||||
@ -1049,3 +1135,87 @@ def test_chat_prompt_template_variable_names() -> None:
|
|||||||
"title": "PromptInput",
|
"title": "PromptInput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_prompt_template_deserializable() -> None:
|
||||||
|
"""Test that the image prompt template is serializable."""
|
||||||
|
load(
|
||||||
|
dumpd(
|
||||||
|
ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"system",
|
||||||
|
[{"type": "image", "source_type": "url", "url": "{url}"}],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("jinja2")
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("template_format", "cache_control_placeholder", "source_data_placeholder"),
|
||||||
|
[
|
||||||
|
("f-string", "{cache_type}", "{source_data}"),
|
||||||
|
("mustache", "{{cache_type}}", "{{source_data}}"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_chat_prompt_template_data_prompt_from_message(
|
||||||
|
template_format: PromptTemplateFormat,
|
||||||
|
cache_control_placeholder: str,
|
||||||
|
source_data_placeholder: str,
|
||||||
|
) -> None:
|
||||||
|
prompt: dict = {
|
||||||
|
"type": "image",
|
||||||
|
"source_type": "base64",
|
||||||
|
"data": f"{source_data_placeholder}",
|
||||||
|
}
|
||||||
|
|
||||||
|
template = ChatPromptTemplate.from_messages(
|
||||||
|
[("human", [prompt])], template_format=template_format
|
||||||
|
)
|
||||||
|
assert template.format_messages(source_data="base64data") == [
|
||||||
|
HumanMessage(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source_type": "base64",
|
||||||
|
"data": "base64data",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# metadata
|
||||||
|
prompt["metadata"] = {"cache_control": {"type": f"{cache_control_placeholder}"}}
|
||||||
|
template = ChatPromptTemplate.from_messages(
|
||||||
|
[("human", [prompt])], template_format=template_format
|
||||||
|
)
|
||||||
|
assert template.format_messages(
|
||||||
|
cache_type="ephemeral", source_data="base64data"
|
||||||
|
) == [
|
||||||
|
HumanMessage(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source_type": "base64",
|
||||||
|
"data": "base64data",
|
||||||
|
"metadata": {"cache_control": {"type": "ephemeral"}},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dict_message_prompt_template_errors_on_jinja2() -> None:
|
||||||
|
prompt = {
|
||||||
|
"type": "image",
|
||||||
|
"source_type": "base64",
|
||||||
|
"data": "{source_data}",
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="jinja2"):
|
||||||
|
_ = ChatPromptTemplate.from_messages(
|
||||||
|
[("human", [prompt])], template_format="jinja2"
|
||||||
|
)
|
||||||
libs/core/tests/unit_tests/prompts/test_message.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+from pathlib import Path
+
+from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
+from langchain_core.prompts.message import _DictMessagePromptTemplate
+
+CUR_DIR = Path(__file__).parent.absolute().resolve()
+
+
+def test__dict_message_prompt_template_fstring() -> None:
+    template = {
+        "role": "assistant",
+        "content": [
+            {"type": "text", "text": "{text1}", "cache_control": {"type": "ephemeral"}},
+        ],
+        "name": "{name1}",
+        "tool_calls": [
+            {
+                "name": "{tool_name1}",
+                "args": {"arg1": "{tool_arg1}"},
+                "id": "1",
+                "type": "tool_call",
+            }
+        ],
+    }
+    prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
+    expected: BaseMessage = AIMessage(
+        [
+            {
+                "type": "text",
+                "text": "important message",
+                "cache_control": {"type": "ephemeral"},
+            },
+        ],
+        name="foo",
+        tool_calls=[
+            {
+                "name": "do_stuff",
+                "args": {"arg1": "important arg1"},
+                "id": "1",
+                "type": "tool_call",
+            }
+        ],
+    )
+    actual = prompt.format_messages(
+        text1="important message",
+        name1="foo",
+        tool_arg1="important arg1",
+        tool_name1="do_stuff",
+    )[0]
+    assert actual == expected
+
+    template = {
+        "role": "tool",
+        "content": "{content1}",
+        "tool_call_id": "1",
+        "name": "{name1}",
+    }
+    prompt = _DictMessagePromptTemplate(template=template, template_format="f-string")
+    expected = ToolMessage("foo", name="bar", tool_call_id="1")
+    actual = prompt.format_messages(content1="foo", name1="bar")[0]
+    assert actual == expected