diff --git a/libs/core/langchain_core/load/mapping.py b/libs/core/langchain_core/load/mapping.py index b4201c18b30..e93027ef4fa 100644 --- a/libs/core/langchain_core/load/mapping.py +++ b/libs/core/langchain_core/load/mapping.py @@ -146,6 +146,12 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = { "image", "ImagePromptTemplate", ), + ("langchain", "prompts", "data", "DataPromptTemplate"): ( + "langchain_core", + "prompts", + "data", + "DataPromptTemplate", + ), ("langchain", "schema", "agent", "AgentActionMessageLog"): ( "langchain_core", "agents", diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 7652bd76e3c..f1073f11034 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -16,6 +16,7 @@ from langchain_core.load.serializable import Serializable from langchain_core.messages import ( AnyMessage, BaseMessage, + DataContentBlock, HumanMessage, get_buffer_string, ) @@ -130,6 +131,22 @@ class ImagePromptValue(PromptValue): return [HumanMessage(content=[cast("dict", self.image_url)])] +class DataPromptValue(PromptValue): + """Prompt value for multi-modal data.""" + + content_block: DataContentBlock + """Multi-modal content block.""" + type: Literal["DataPromptValue"] = "DataPromptValue" + + def to_string(self) -> str: + """Return source data as a string.""" + return self.content_block["source"] + + def to_messages(self) -> list[BaseMessage]: + """Return prompt (image URL) as messages.""" + return [HumanMessage(content=[cast("dict", self.content_block)])] + + class ChatPromptValueConcrete(ChatPromptValue): """Chat prompt value which explicitly lists out the message types it accepts. diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 5f75df829f6..d9f10a922fd 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -31,13 +31,16 @@ from langchain_core.messages import ( AnyMessage, BaseMessage, ChatMessage, + DataContentBlock, HumanMessage, SystemMessage, convert_to_messages, + is_data_content_block, ) from langchain_core.messages.base import get_msg_title_repr from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.data import DataPromptTemplate from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import ( @@ -468,7 +471,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" prompt: Union[ - StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]] + StringPromptTemplate, + list[Union[StringPromptTemplate, ImagePromptTemplate, DataPromptTemplate]], ] """Prompt template.""" additional_kwargs: dict = Field(default_factory=dict) @@ -479,7 +483,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): @classmethod def from_template( cls: type[Self], - template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]], + template: Union[ + str, + list[Union[str, _TextTemplateParam, _ImageTemplateParam, DataContentBlock]], + ], template_format: PromptTemplateFormat = "f-string", *, partial_variables: Optional[dict[str, Any]] = None, @@ -562,6 +569,23 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): msg = f"Invalid image template: {tmpl}" raise ValueError(msg) prompt.append(img_template_obj) + elif isinstance(tmpl, dict) and is_data_content_block(tmpl): # type: ignore[arg-type] + data_template = cast("DataContentBlock", tmpl) + input_variables = [] + for key in ["source", "source_type", "mime_type"]: + if key in data_template: + input_variables.extend( + get_template_variables( + data_template[key], # type: ignore[literal-required] + template_format, + ) + ) + data_template_obj = DataPromptTemplate( + input_variables=input_variables, + template=data_template, + template_format=template_format, + ) + prompt.append(data_template_obj) else: msg = f"Invalid template: {tmpl}" raise ValueError(msg) @@ -639,11 +663,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): for prompt in self.prompt: inputs = {var: kwargs[var] for var in prompt.input_variables} if isinstance(prompt, StringPromptTemplate): - formatted: Union[str, ImageURL] = prompt.format(**inputs) + formatted: Union[str, ImageURL, DataContentBlock] = prompt.format( + **inputs + ) content.append({"type": "text", "text": formatted}) elif isinstance(prompt, ImagePromptTemplate): formatted = prompt.format(**inputs) content.append({"type": "image_url", "image_url": formatted}) + elif isinstance(prompt, DataPromptTemplate): + formatted = prompt.format(**inputs) + content.append(formatted) return self._msg_class( content=content, additional_kwargs=self.additional_kwargs ) diff --git a/libs/core/langchain_core/prompts/data.py b/libs/core/langchain_core/prompts/data.py new file mode 100644 index 00000000000..e43ba452d34 --- /dev/null +++ b/libs/core/langchain_core/prompts/data.py @@ -0,0 +1,145 @@ +"""Image prompt template for a multimodal model.""" + +from typing import Any, cast + +from pydantic import Field + +from langchain_core.messages import DataContentBlock +from langchain_core.prompt_values import DataPromptValue, PromptValue +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.string import ( + DEFAULT_FORMATTER_MAPPING, + PromptTemplateFormat, +) +from langchain_core.runnables import run_in_executor + + +class DataPromptTemplate(BasePromptTemplate[DataContentBlock]): + """Prompt template for a multi-modal model.""" + + template: dict = Field(default_factory=dict) + """Template for the prompt.""" + template_format: PromptTemplateFormat = "f-string" + """The format of the prompt template. + Options are: 'f-string', 'mustache', 'jinja2'.""" + + def __init__(self, **kwargs: Any) -> None: + """Create a prompt template for multi-modal data.""" + if "input_variables" not in kwargs: + kwargs["input_variables"] = [] + + overlap = set(kwargs["input_variables"]) & { + "source", + "source_type", + "mime_type", + "metadata", + } + if overlap: + msg = ( + "input_variables for the template cannot contain" + " any of 'source', 'source_type', 'mime_type', or 'metadata'." + f" Found: {overlap}" + ) + raise ValueError(msg) + super().__init__(**kwargs) + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "data-prompt" + + @classmethod + def get_lc_namespace(cls) -> list[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "prompts", "data"] + + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + """ + return DataPromptValue(content_block=self.format(**kwargs)) + + async def aformat_prompt(self, **kwargs: Any) -> PromptValue: + """Async format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + """ + return DataPromptValue(content_block=await self.aformat(**kwargs)) + + def format( + self, + **kwargs: Any, + ) -> DataContentBlock: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Raises: + ValueError: If the url is not provided. + ValueError: If the url is not a string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + formatted = {} + for k, v in self.template.items(): + if isinstance(v, str): + formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format]( + v, **kwargs + ) + else: + formatted[k] = v + + block = {} + for k in ["type", "source_type", "source", "mime_type", "metadata"]: + value = kwargs.get(k) or formatted.get(k) + if value: + block[k] = value + + for required_field in ["source", "source_type"]: + if required_field not in block: + msg = f"Missing required field: {required_field}" + raise ValueError(msg) + + return cast("DataContentBlock", block) + + async def aformat(self, **kwargs: Any) -> DataContentBlock: + """Async format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Raises: + ValueError: If the path or url is not a string. + """ + return await run_in_executor(None, self.format, **kwargs) + + def pretty_repr(self, html: bool = False) -> str: + """Return a pretty representation of the prompt. + + Args: + html: Whether to return an html formatted string. + + Returns: + A pretty representation of the prompt. + """ + raise NotImplementedError diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 84a83c6ae2d..dc063e98d5d 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -11,7 +11,7 @@ from syrupy import SnapshotAssertion from langchain_core._api.deprecation import ( LangChainPendingDeprecationWarning, ) -from langchain_core.load import dumpd, load +from langchain_core.load import dump, dumpd, load, loads from langchain_core.messages import ( AIMessage, BaseMessage, @@ -1049,3 +1049,75 @@ def test_chat_prompt_template_variable_names() -> None: "title": "PromptInput", "type": "object", } + + +def test_data_prompt_template_deserializable() -> None: + """Test that the image prompt template is serializable.""" + loads( + dump.dumps( + ChatPromptTemplate.from_messages( + [ + ( + "system", + [{"type": "image", "source_type": "url", "source": "{url}"}], + ) + ] + ) + ) + ) + + +@pytest.mark.requires("jinja2") +@pytest.mark.parametrize( + ("template_format", "mime_type_placeholder", "source_data_placeholder"), + [ + ("f-string", "{media_type}", "{source_data}"), + ("mustache", "{{media_type}}", "{{source_data}}"), + ("jinja2", "{{ media_type }}", "{{ source_data }}"), + ], +) +def test_chat_prompt_template_data_prompt_from_message( + template_format: PromptTemplateFormat, + mime_type_placeholder: str, + source_data_placeholder: str, +) -> None: + prompt = { + "type": "image", + "source_type": "base64", + "source": f"{source_data_placeholder}", + } + + template = ChatPromptTemplate.from_messages( + [("human", [prompt])], template_format=template_format + ) + assert template.format_messages(source_data="base64data") == [ + HumanMessage( + content=[ + { + "type": "image", + "source_type": "base64", + "source": "base64data", + } + ] + ) + ] + + # mime_type + prompt["mime_type"] = f"{mime_type_placeholder}" + template = ChatPromptTemplate.from_messages( + [("human", [prompt])], template_format=template_format + ) + assert template.format_messages( + media_type="image/png", source_data="base64data" + ) == [ + HumanMessage( + content=[ + { + "type": "image", + "source_type": "base64", + "source": "base64data", + "mime_type": "image/png", + } + ] + ) + ]