mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
add prompt template
This commit is contained in:
parent
35fbe24532
commit
7e3caec720
@ -146,6 +146,12 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
|
|||||||
"image",
|
"image",
|
||||||
"ImagePromptTemplate",
|
"ImagePromptTemplate",
|
||||||
),
|
),
|
||||||
|
("langchain", "prompts", "data", "DataPromptTemplate"): (
|
||||||
|
"langchain_core",
|
||||||
|
"prompts",
|
||||||
|
"data",
|
||||||
|
"DataPromptTemplate",
|
||||||
|
),
|
||||||
("langchain", "schema", "agent", "AgentActionMessageLog"): (
|
("langchain", "schema", "agent", "AgentActionMessageLog"): (
|
||||||
"langchain_core",
|
"langchain_core",
|
||||||
"agents",
|
"agents",
|
||||||
|
@ -16,6 +16,7 @@ from langchain_core.load.serializable import Serializable
|
|||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
|
DataContentBlock,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
get_buffer_string,
|
get_buffer_string,
|
||||||
)
|
)
|
||||||
@ -130,6 +131,22 @@ class ImagePromptValue(PromptValue):
|
|||||||
return [HumanMessage(content=[cast("dict", self.image_url)])]
|
return [HumanMessage(content=[cast("dict", self.image_url)])]
|
||||||
|
|
||||||
|
|
||||||
|
class DataPromptValue(PromptValue):
    """Prompt value wrapping a single multi-modal data content block."""

    content_block: DataContentBlock
    """Multi-modal content block."""
    type: Literal["DataPromptValue"] = "DataPromptValue"

    def to_string(self) -> str:
        """Return the ``source`` entry of the content block."""
        block = self.content_block
        return block["source"]

    def to_messages(self) -> list[BaseMessage]:
        """Return the content block as the sole content of one ``HumanMessage``."""
        payload = cast("dict", self.content_block)
        return [HumanMessage(content=[payload])]
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptValueConcrete(ChatPromptValue):
|
class ChatPromptValueConcrete(ChatPromptValue):
|
||||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||||
|
|
||||||
|
@ -31,13 +31,16 @@ from langchain_core.messages import (
|
|||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
DataContentBlock,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
convert_to_messages,
|
convert_to_messages,
|
||||||
|
is_data_content_block,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.base import get_msg_title_repr
|
from langchain_core.messages.base import get_msg_title_repr
|
||||||
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.prompts.data import DataPromptTemplate
|
||||||
from langchain_core.prompts.image import ImagePromptTemplate
|
from langchain_core.prompts.image import ImagePromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import (
|
from langchain_core.prompts.string import (
|
||||||
@ -468,7 +471,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
"""Human message prompt template. This is a message sent from the user."""
|
"""Human message prompt template. This is a message sent from the user."""
|
||||||
|
|
||||||
prompt: Union[
|
prompt: Union[
|
||||||
StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
|
StringPromptTemplate,
|
||||||
|
list[Union[StringPromptTemplate, ImagePromptTemplate, DataPromptTemplate]],
|
||||||
]
|
]
|
||||||
"""Prompt template."""
|
"""Prompt template."""
|
||||||
additional_kwargs: dict = Field(default_factory=dict)
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
@ -479,7 +483,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_template(
|
def from_template(
|
||||||
cls: type[Self],
|
cls: type[Self],
|
||||||
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
template: Union[
|
||||||
|
str,
|
||||||
|
list[Union[str, _TextTemplateParam, _ImageTemplateParam, DataContentBlock]],
|
||||||
|
],
|
||||||
template_format: PromptTemplateFormat = "f-string",
|
template_format: PromptTemplateFormat = "f-string",
|
||||||
*,
|
*,
|
||||||
partial_variables: Optional[dict[str, Any]] = None,
|
partial_variables: Optional[dict[str, Any]] = None,
|
||||||
@ -562,6 +569,23 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
msg = f"Invalid image template: {tmpl}"
|
msg = f"Invalid image template: {tmpl}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
prompt.append(img_template_obj)
|
prompt.append(img_template_obj)
|
||||||
|
elif isinstance(tmpl, dict) and is_data_content_block(tmpl): # type: ignore[arg-type]
|
||||||
|
data_template = cast("DataContentBlock", tmpl)
|
||||||
|
input_variables = []
|
||||||
|
for key in ["source", "source_type", "mime_type"]:
|
||||||
|
if key in data_template:
|
||||||
|
input_variables.extend(
|
||||||
|
get_template_variables(
|
||||||
|
data_template[key], # type: ignore[literal-required]
|
||||||
|
template_format,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
data_template_obj = DataPromptTemplate(
|
||||||
|
input_variables=input_variables,
|
||||||
|
template=data_template,
|
||||||
|
template_format=template_format,
|
||||||
|
)
|
||||||
|
prompt.append(data_template_obj)
|
||||||
else:
|
else:
|
||||||
msg = f"Invalid template: {tmpl}"
|
msg = f"Invalid template: {tmpl}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -639,11 +663,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
|||||||
for prompt in self.prompt:
|
for prompt in self.prompt:
|
||||||
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
||||||
if isinstance(prompt, StringPromptTemplate):
|
if isinstance(prompt, StringPromptTemplate):
|
||||||
formatted: Union[str, ImageURL] = prompt.format(**inputs)
|
formatted: Union[str, ImageURL, DataContentBlock] = prompt.format(
|
||||||
|
**inputs
|
||||||
|
)
|
||||||
content.append({"type": "text", "text": formatted})
|
content.append({"type": "text", "text": formatted})
|
||||||
elif isinstance(prompt, ImagePromptTemplate):
|
elif isinstance(prompt, ImagePromptTemplate):
|
||||||
formatted = prompt.format(**inputs)
|
formatted = prompt.format(**inputs)
|
||||||
content.append({"type": "image_url", "image_url": formatted})
|
content.append({"type": "image_url", "image_url": formatted})
|
||||||
|
elif isinstance(prompt, DataPromptTemplate):
|
||||||
|
formatted = prompt.format(**inputs)
|
||||||
|
content.append(formatted)
|
||||||
return self._msg_class(
|
return self._msg_class(
|
||||||
content=content, additional_kwargs=self.additional_kwargs
|
content=content, additional_kwargs=self.additional_kwargs
|
||||||
)
|
)
|
||||||
|
145
libs/core/langchain_core/prompts/data.py
Normal file
145
libs/core/langchain_core/prompts/data.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
"""Data prompt template for a multimodal model."""
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain_core.messages import DataContentBlock
|
||||||
|
from langchain_core.prompt_values import DataPromptValue, PromptValue
|
||||||
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.prompts.string import (
|
||||||
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
|
PromptTemplateFormat,
|
||||||
|
)
|
||||||
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
|
|
||||||
|
class DataPromptTemplate(BasePromptTemplate[DataContentBlock]):
    """Prompt template that renders a multi-modal data content block.

    String values of ``template`` are rendered with the formatter selected by
    ``template_format``; non-string values are passed through unchanged.
    """

    template: dict = Field(default_factory=dict)
    """Template for the prompt."""
    template_format: PromptTemplateFormat = "f-string"
    """The format of the prompt template.
    Options are: 'f-string', 'mustache', 'jinja2'."""

    def __init__(self, **kwargs: Any) -> None:
        """Create a prompt template for multi-modal data.

        Args:
            kwargs: Field values for the template. ``input_variables``
                defaults to an empty list when omitted.

        Raises:
            ValueError: If ``input_variables`` contains any of ``source``,
                ``source_type``, ``mime_type`` or ``metadata``.
        """
        kwargs.setdefault("input_variables", [])

        # These names are content-block keys; ``format`` reads them directly
        # from its keyword arguments, so they cannot double as variables.
        reserved = {"source", "source_type", "mime_type", "metadata"}
        overlap = set(kwargs["input_variables"]) & reserved
        if overlap:
            # Sorted so the message does not depend on set iteration order.
            msg = (
                "input_variables for the template cannot contain"
                " any of 'source', 'source_type', 'mime_type', or 'metadata'."
                f" Found: {sorted(overlap)}"
            )
            raise ValueError(msg)
        super().__init__(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "data-prompt"

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "data"]

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A ``DataPromptValue`` wrapping the formatted content block.
        """
        return DataPromptValue(content_block=self.format(**kwargs))

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        """Async format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A ``DataPromptValue`` wrapping the formatted content block.
        """
        return DataPromptValue(content_block=await self.aformat(**kwargs))

    def format(
        self,
        **kwargs: Any,
    ) -> DataContentBlock:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            The formatted content block. Only the keys ``type``,
            ``source_type``, ``source``, ``mime_type`` and ``metadata`` are
            kept, and only when their value is truthy.

        Raises:
            ValueError: If ``source`` or ``source_type`` is missing (or
                empty) after formatting.

        Example:

            .. code-block:: python

                prompt.format(variable1="foo")
        """
        # Resolve the formatter once instead of once per template entry.
        formatter = DEFAULT_FORMATTER_MAPPING[self.template_format]
        formatted = {
            key: formatter(value, **kwargs) if isinstance(value, str) else value
            for key, value in self.template.items()
        }

        block = {}
        for key in ("type", "source_type", "source", "mime_type", "metadata"):
            # A truthy keyword argument wins over the rendered template value.
            value = kwargs.get(key) or formatted.get(key)
            if value:
                block[key] = value

        for required_field in ("source", "source_type"):
            if required_field not in block:
                msg = f"Missing required field: {required_field}"
                raise ValueError(msg)

        return cast("DataContentBlock", block)

    async def aformat(self, **kwargs: Any) -> DataContentBlock:
        """Async format the prompt with the inputs.

        Runs the synchronous ``format`` through ``run_in_executor``.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            The formatted content block.

        Raises:
            ValueError: If ``source`` or ``source_type`` is missing (or
                empty) after formatting.
        """
        return await run_in_executor(None, self.format, **kwargs)

    def pretty_repr(self, html: bool = False) -> str:
        """Return a pretty representation of the prompt.

        Args:
            html: Whether to return an html formatted string.

        Returns:
            A pretty representation of the prompt.

        Raises:
            NotImplementedError: Always; no pretty representation is
                implemented for this template.
        """
        raise NotImplementedError
|
@ -11,7 +11,7 @@ from syrupy import SnapshotAssertion
|
|||||||
from langchain_core._api.deprecation import (
|
from langchain_core._api.deprecation import (
|
||||||
LangChainPendingDeprecationWarning,
|
LangChainPendingDeprecationWarning,
|
||||||
)
|
)
|
||||||
from langchain_core.load import dumpd, load
|
from langchain_core.load import dump, dumpd, load, loads
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -1049,3 +1049,75 @@ def test_chat_prompt_template_variable_names() -> None:
|
|||||||
"title": "PromptInput",
|
"title": "PromptInput",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_prompt_template_deserializable() -> None:
    """Test that a chat prompt with a data block can be dumped and loaded."""
    data_block = {"type": "image", "source_type": "url", "source": "{url}"}
    template = ChatPromptTemplate.from_messages([("system", [data_block])])
    serialized = dump.dumps(template)
    # Passing means loading the serialized template does not raise.
    loads(serialized)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("jinja2")
@pytest.mark.parametrize(
    ("template_format", "mime_type_placeholder", "source_data_placeholder"),
    [
        ("f-string", "{media_type}", "{source_data}"),
        ("mustache", "{{media_type}}", "{{source_data}}"),
        ("jinja2", "{{ media_type }}", "{{ source_data }}"),
    ],
)
def test_chat_prompt_template_data_prompt_from_message(
    template_format: PromptTemplateFormat,
    mime_type_placeholder: str,
    source_data_placeholder: str,
) -> None:
    """Test data content blocks in chat prompts for each template format."""
    base_block = {
        "type": "image",
        "source_type": "base64",
        "source": f"{source_data_placeholder}",
    }

    # Only the source is templated.
    source_only = ChatPromptTemplate.from_messages(
        [("human", [base_block])], template_format=template_format
    )
    rendered = source_only.format_messages(source_data="base64data")
    expected_block = {
        "type": "image",
        "source_type": "base64",
        "source": "base64data",
    }
    assert rendered == [HumanMessage(content=[expected_block])]

    # mime_type
    block_with_mime = {**base_block, "mime_type": f"{mime_type_placeholder}"}
    with_mime = ChatPromptTemplate.from_messages(
        [("human", [block_with_mime])], template_format=template_format
    )
    rendered = with_mime.format_messages(
        media_type="image/png", source_data="base64data"
    )
    expected_block = {
        "type": "image",
        "source_type": "base64",
        "source": "base64data",
        "mime_type": "image/png",
    }
    assert rendered == [HumanMessage(content=[expected_block])]
|
||||||
|
Loading…
Reference in New Issue
Block a user