add prompt template

This commit is contained in:
Chester Curme 2025-04-09 14:39:24 -04:00
parent 35fbe24532
commit 7e3caec720
5 changed files with 273 additions and 4 deletions

View File

@ -146,6 +146,12 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
"image",
"ImagePromptTemplate",
),
("langchain", "prompts", "data", "DataPromptTemplate"): (
"langchain_core",
"prompts",
"data",
"DataPromptTemplate",
),
("langchain", "schema", "agent", "AgentActionMessageLog"): (
"langchain_core",
"agents",

View File

@ -16,6 +16,7 @@ from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AnyMessage,
BaseMessage,
DataContentBlock,
HumanMessage,
get_buffer_string,
)
@ -130,6 +131,22 @@ class ImagePromptValue(PromptValue):
return [HumanMessage(content=[cast("dict", self.image_url)])]
class DataPromptValue(PromptValue):
"""Prompt value for multi-modal data."""
content_block: DataContentBlock
"""Multi-modal content block."""
type: Literal["DataPromptValue"] = "DataPromptValue"
def to_string(self) -> str:
"""Return source data as a string."""
return self.content_block["source"]
def to_messages(self) -> list[BaseMessage]:
"""Return prompt (image URL) as messages."""
return [HumanMessage(content=[cast("dict", self.content_block)])]
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.

View File

@ -31,13 +31,16 @@ from langchain_core.messages import (
AnyMessage,
BaseMessage,
ChatMessage,
DataContentBlock,
HumanMessage,
SystemMessage,
convert_to_messages,
is_data_content_block,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.data import DataPromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
@ -468,7 +471,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
StringPromptTemplate,
list[Union[StringPromptTemplate, ImagePromptTemplate, DataPromptTemplate]],
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
@ -479,7 +483,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@classmethod
def from_template(
cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template: Union[
str,
list[Union[str, _TextTemplateParam, _ImageTemplateParam, DataContentBlock]],
],
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
@ -562,6 +569,23 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
msg = f"Invalid image template: {tmpl}"
raise ValueError(msg)
prompt.append(img_template_obj)
elif isinstance(tmpl, dict) and is_data_content_block(tmpl): # type: ignore[arg-type]
data_template = cast("DataContentBlock", tmpl)
input_variables = []
for key in ["source", "source_type", "mime_type"]:
if key in data_template:
input_variables.extend(
get_template_variables(
data_template[key], # type: ignore[literal-required]
template_format,
)
)
data_template_obj = DataPromptTemplate(
input_variables=input_variables,
template=data_template,
template_format=template_format,
)
prompt.append(data_template_obj)
else:
msg = f"Invalid template: {tmpl}"
raise ValueError(msg)
@ -639,11 +663,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
formatted: Union[str, ImageURL, DataContentBlock] = prompt.format(
**inputs
)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, DataPromptTemplate):
formatted = prompt.format(**inputs)
content.append(formatted)
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)

View File

@ -0,0 +1,145 @@
"""Image prompt template for a multimodal model."""
from typing import Any, cast
from pydantic import Field
from langchain_core.messages import DataContentBlock
from langchain_core.prompt_values import DataPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
)
from langchain_core.runnables import run_in_executor
class DataPromptTemplate(BasePromptTemplate[DataContentBlock]):
"""Prompt template for a multi-modal model."""
template: dict = Field(default_factory=dict)
"""Template for the prompt."""
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""
def __init__(self, **kwargs: Any) -> None:
"""Create a prompt template for multi-modal data."""
if "input_variables" not in kwargs:
kwargs["input_variables"] = []
overlap = set(kwargs["input_variables"]) & {
"source",
"source_type",
"mime_type",
"metadata",
}
if overlap:
msg = (
"input_variables for the template cannot contain"
" any of 'source', 'source_type', 'mime_type', or 'metadata'."
f" Found: {overlap}"
)
raise ValueError(msg)
super().__init__(**kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "data-prompt"
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "data"]
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
"""
return DataPromptValue(content_block=self.format(**kwargs))
async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
"""Async format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
"""
return DataPromptValue(content_block=await self.aformat(**kwargs))
def format(
self,
**kwargs: Any,
) -> DataContentBlock:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Raises:
ValueError: If the url is not provided.
ValueError: If the url is not a string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
formatted = {}
for k, v in self.template.items():
if isinstance(v, str):
formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format](
v, **kwargs
)
else:
formatted[k] = v
block = {}
for k in ["type", "source_type", "source", "mime_type", "metadata"]:
value = kwargs.get(k) or formatted.get(k)
if value:
block[k] = value
for required_field in ["source", "source_type"]:
if required_field not in block:
msg = f"Missing required field: {required_field}"
raise ValueError(msg)
return cast("DataContentBlock", block)
async def aformat(self, **kwargs: Any) -> DataContentBlock:
"""Async format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Raises:
ValueError: If the path or url is not a string.
"""
return await run_in_executor(None, self.format, **kwargs)
def pretty_repr(self, html: bool = False) -> str:
"""Return a pretty representation of the prompt.
Args:
html: Whether to return an html formatted string.
Returns:
A pretty representation of the prompt.
"""
raise NotImplementedError

View File

@ -11,7 +11,7 @@ from syrupy import SnapshotAssertion
from langchain_core._api.deprecation import (
LangChainPendingDeprecationWarning,
)
from langchain_core.load import dumpd, load
from langchain_core.load import dump, dumpd, load, loads
from langchain_core.messages import (
AIMessage,
BaseMessage,
@ -1049,3 +1049,75 @@ def test_chat_prompt_template_variable_names() -> None:
"title": "PromptInput",
"type": "object",
}
def test_data_prompt_template_deserializable() -> None:
"""Test that the image prompt template is serializable."""
loads(
dump.dumps(
ChatPromptTemplate.from_messages(
[
(
"system",
[{"type": "image", "source_type": "url", "source": "{url}"}],
)
]
)
)
)
@pytest.mark.requires("jinja2")
@pytest.mark.parametrize(
("template_format", "mime_type_placeholder", "source_data_placeholder"),
[
("f-string", "{media_type}", "{source_data}"),
("mustache", "{{media_type}}", "{{source_data}}"),
("jinja2", "{{ media_type }}", "{{ source_data }}"),
],
)
def test_chat_prompt_template_data_prompt_from_message(
template_format: PromptTemplateFormat,
mime_type_placeholder: str,
source_data_placeholder: str,
) -> None:
prompt = {
"type": "image",
"source_type": "base64",
"source": f"{source_data_placeholder}",
}
template = ChatPromptTemplate.from_messages(
[("human", [prompt])], template_format=template_format
)
assert template.format_messages(source_data="base64data") == [
HumanMessage(
content=[
{
"type": "image",
"source_type": "base64",
"source": "base64data",
}
]
)
]
# mime_type
prompt["mime_type"] = f"{mime_type_placeholder}"
template = ChatPromptTemplate.from_messages(
[("human", [prompt])], template_format=template_format
)
assert template.format_messages(
media_type="image/png", source_data="base64data"
) == [
HumanMessage(
content=[
{
"type": "image",
"source_type": "base64",
"source": "base64data",
"mime_type": "image/png",
}
]
)
]