core: fix Image prompt template hardcoded template format (#27495)

Fixes #27411 

**Description:** Adds `template_format` to the `ImagePromptTemplate`
class and updates passing in the `template_format` parameter from
ChatPromptTemplate instead of the hardcoded "f-string".
Also updated docs and typing related to `template_format` to be more
up-to-date and specific.

**Dependencies:** None

**Add tests and docs**: Added unit tests to validate fix. Needed to
update `test_chat` snapshot due to adding new attribute
`template_format` in `ImagePromptTemplate`.

---------

Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
Chun Kang Lu 2024-10-21 17:31:40 -04:00 committed by GitHub
parent 403c0ea801
commit 380449a7a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 131 additions and 25 deletions

View File

@ -1 +1,2 @@
jinja2>=3,<4 jinja2>=3,<4
mustache>=0.1.4,<1

View File

@ -8,7 +8,6 @@ from pathlib import Path
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
Literal,
Optional, Optional,
TypedDict, TypedDict,
TypeVar, TypeVar,
@ -40,7 +39,11 @@ from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables from langchain_core.prompts.string import (
PromptTemplateFormat,
StringPromptTemplate,
get_template_variables,
)
from langchain_core.utils import get_colored_text from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env from langchain_core.utils.interactive_env import is_interactive_env
@ -296,7 +299,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
def from_template( def from_template(
cls: type[MessagePromptTemplateT], cls: type[MessagePromptTemplateT],
template: str, template: str,
template_format: str = "f-string", template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None, partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> MessagePromptTemplateT: ) -> MessagePromptTemplateT:
@ -486,7 +489,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
def from_template( def from_template(
cls: type[_StringImageMessagePromptTemplateT], cls: type[_StringImageMessagePromptTemplateT],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]], template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string", template_format: PromptTemplateFormat = "f-string",
*, *,
partial_variables: Optional[dict[str, Any]] = None, partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
@ -495,7 +498,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Args: Args:
template: a template. template: a template.
template_format: format of the template. Defaults to "f-string". template_format: format of the template.
Options are: 'f-string', 'mustache', 'jinja2'. Defaults to "f-string".
partial_variables: A dictionary of variables that can be used too partially. partial_variables: A dictionary of variables that can be used too partially.
Defaults to None. Defaults to None.
**kwargs: keyword arguments to pass to the constructor. **kwargs: keyword arguments to pass to the constructor.
@ -533,7 +537,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
img_template = cast(_ImageTemplateParam, tmpl)["image_url"] img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
input_variables = [] input_variables = []
if isinstance(img_template, str): if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string") vars = get_template_variables(img_template, template_format)
if vars: if vars:
if len(vars) > 1: if len(vars) > 1:
msg = ( msg = (
@ -545,7 +549,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
input_variables = [vars[0]] input_variables = [vars[0]]
img_template = {"url": img_template} img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate( img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template input_variables=input_variables,
template=img_template,
template_format=template_format,
) )
elif isinstance(img_template, dict): elif isinstance(img_template, dict):
img_template = dict(img_template) img_template = dict(img_template)
@ -553,11 +559,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
if key in img_template: if key in img_template:
input_variables.extend( input_variables.extend(
get_template_variables( get_template_variables(
img_template[key], "f-string" img_template[key], template_format
) )
) )
img_template_obj = ImagePromptTemplate( img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template input_variables=input_variables,
template=img_template,
template_format=template_format,
) )
else: else:
msg = f"Invalid image template: {tmpl}" msg = f"Invalid image template: {tmpl}"
@ -943,7 +951,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
self, self,
messages: Sequence[MessageLikeRepresentation], messages: Sequence[MessageLikeRepresentation],
*, *,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", template_format: PromptTemplateFormat = "f-string",
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
@ -1160,7 +1168,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def from_messages( def from_messages(
cls, cls,
messages: Sequence[MessageLikeRepresentation], messages: Sequence[MessageLikeRepresentation],
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", template_format: PromptTemplateFormat = "f-string",
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
@ -1354,7 +1362,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type( def _create_template_from_message_type(
message_type: str, message_type: str,
template: Union[str, list], template: Union[str, list],
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", template_format: PromptTemplateFormat = "f-string",
) -> BaseMessagePromptTemplate: ) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string. """Create a message prompt template from a message type and template string.
@ -1426,7 +1434,7 @@ def _create_template_from_message_type(
def _convert_to_message( def _convert_to_message(
message: MessageLikeRepresentation, message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", template_format: PromptTemplateFormat = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats. """Instantiate a message from a variety of message formats.

View File

@ -9,6 +9,7 @@ from typing_extensions import Self
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import ( from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
StringPromptTemplate, StringPromptTemplate,
) )
@ -36,8 +37,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
prefix: Optional[StringPromptTemplate] = None prefix: Optional[StringPromptTemplate] = None
"""A PromptTemplate to put before the examples.""" """A PromptTemplate to put before the examples."""
template_format: str = "f-string" template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template.
Options are: 'f-string', 'jinja2', 'mustache'."""
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""

View File

@ -4,6 +4,10 @@ from pydantic import Field
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
)
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor
from langchain_core.utils import image as image_utils from langchain_core.utils import image as image_utils
@ -13,6 +17,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
template: dict = Field(default_factory=dict) template: dict = Field(default_factory=dict)
"""Template for the prompt.""" """Template for the prompt."""
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
if "input_variables" not in kwargs: if "input_variables" not in kwargs:
@ -85,7 +92,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
formatted = {} formatted = {}
for k, v in self.template.items(): for k, v in self.template.items():
if isinstance(v, str): if isinstance(v, str):
formatted[k] = v.format(**kwargs) formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format](
v, **kwargs
)
else: else:
formatted[k] = v formatted[k] = v
url = kwargs.get("url") or formatted.get("url") url = kwargs.get("url") or formatted.get("url")

View File

@ -4,12 +4,13 @@ from __future__ import annotations
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from langchain_core.prompts.string import ( from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
StringPromptTemplate, StringPromptTemplate,
check_valid_template, check_valid_template,
get_template_variables, get_template_variables,
@ -24,7 +25,8 @@ class PromptTemplate(StringPromptTemplate):
A prompt template consists of a string template. It accepts a set of parameters A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model. from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax. The template can be formatted using either f-strings (default), jinja2,
or mustache syntax.
*Security warning*: *Security warning*:
Prefer using `template_format="f-string"` instead of Prefer using `template_format="f-string"` instead of
@ -67,7 +69,7 @@ class PromptTemplate(StringPromptTemplate):
template: str template: str
"""The prompt template.""" """The prompt template."""
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string" template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template. """The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'.""" Options are: 'f-string', 'mustache', 'jinja2'."""
@ -248,7 +250,7 @@ class PromptTemplate(StringPromptTemplate):
cls, cls,
template: str, template: str,
*, *,
template_format: str = "f-string", template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None, partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> PromptTemplate: ) -> PromptTemplate:
@ -270,7 +272,7 @@ class PromptTemplate(StringPromptTemplate):
Args: Args:
template: The template to load. template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2, template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings. `mustache` for mustache, and `f-string` for f-strings.
Defaults to `f-string`. Defaults to `f-string`.
partial_variables: A dictionary of variables that can be used to partially partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is fill in the template. For example, if the template is

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC from abc import ABC
from string import Formatter from string import Formatter
from typing import Any, Callable from typing import Any, Callable, Literal
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -16,6 +16,8 @@ from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env from langchain_core.utils.interactive_env import is_interactive_env
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
def jinja2_formatter(template: str, /, **kwargs: Any) -> str: def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
"""Format a template using jinja2. """Format a template using jinja2.

View File

@ -2,7 +2,6 @@ from collections.abc import Iterator, Mapping, Sequence
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Literal,
Optional, Optional,
Union, Union,
) )
@ -15,6 +14,7 @@ from langchain_core.prompts.chat import (
ChatPromptTemplate, ChatPromptTemplate,
MessageLikeRepresentation, MessageLikeRepresentation,
) )
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.runnables.base import ( from langchain_core.runnables.base import (
Other, Other,
Runnable, Runnable,
@ -38,7 +38,7 @@ class StructuredPrompt(ChatPromptTemplate):
schema_: Optional[Union[dict, type[BaseModel]]] = None, schema_: Optional[Union[dict, type[BaseModel]]] = None,
*, *,
structured_output_kwargs: Optional[dict[str, Any]] = None, structured_output_kwargs: Optional[dict[str, Any]] = None,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", template_format: PromptTemplateFormat = "f-string",
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
schema_ = schema_ or kwargs.pop("schema") schema_ = schema_ or kwargs.pop("schema")

View File

@ -3119,6 +3119,7 @@
'template': dict({ 'template': dict({
'url': 'data:image/jpeg;base64,{my_image}', 'url': 'data:image/jpeg;base64,{my_image}',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',
@ -3138,6 +3139,7 @@
'template': dict({ 'template': dict({
'url': 'data:image/jpeg;base64,{my_image}', 'url': 'data:image/jpeg;base64,{my_image}',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',
@ -3157,6 +3159,7 @@
'template': dict({ 'template': dict({
'url': '{my_other_image}', 'url': '{my_other_image}',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',
@ -3177,6 +3180,7 @@
'detail': 'medium', 'detail': 'medium',
'url': '{my_other_image}', 'url': '{my_other_image}',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',
@ -3195,6 +3199,7 @@
'template': dict({ 'template': dict({
'url': 'https://www.langchain.com/image.png', 'url': 'https://www.langchain.com/image.png',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',
@ -3213,6 +3218,7 @@
'template': dict({ 'template': dict({
'url': 'data:image/jpeg;base64,foobar', 'url': 'data:image/jpeg;base64,foobar',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',
@ -3231,6 +3237,7 @@
'template': dict({ 'template': dict({
'url': 'data:image/jpeg;base64,foobar', 'url': 'data:image/jpeg;base64,foobar',
}), }),
'template_format': 'f-string',
}), }),
'lc': 1, 'lc': 1,
'name': 'ImagePromptTemplate', 'name': 'ImagePromptTemplate',

View File

@ -31,6 +31,7 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
_convert_to_message, _convert_to_message,
) )
from langchain_core.prompts.string import PromptTemplateFormat
from tests.unit_tests.pydantic_utils import _normalize_schema from tests.unit_tests.pydantic_utils import _normalize_schema
@ -298,6 +299,77 @@ def test_chat_prompt_template_from_messages_mustache() -> None:
] ]
@pytest.mark.requires("jinja2")
def test_chat_prompt_template_from_messages_jinja2() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful AI bot. Your name is {{ name }}."),
("human", "Hello, how are you doing?"),
("ai", "I'm doing well, thanks!"),
("human", "{{ user_input }}"),
],
"jinja2",
)
messages = template.format_messages(name="Bob", user_input="What is your name?")
assert messages == [
SystemMessage(
content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
),
HumanMessage(
content="Hello, how are you doing?", additional_kwargs={}, example=False
),
AIMessage(
content="I'm doing well, thanks!", additional_kwargs={}, example=False
),
HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
]
@pytest.mark.requires("jinja2")
@pytest.mark.requires("mustache")
@pytest.mark.parametrize(
"template_format,image_type_placeholder,image_data_placeholder",
[
("f-string", "{image_type}", "{image_data}"),
("mustache", "{{image_type}}", "{{image_data}}"),
("jinja2", "{{ image_type }}", "{{ image_data }}"),
],
)
def test_chat_prompt_template_image_prompt_from_message(
template_format: PromptTemplateFormat,
image_type_placeholder: str,
image_data_placeholder: str,
) -> None:
prompt = {
"type": "image_url",
"image_url": {
"url": f"data:{image_type_placeholder};base64, {image_data_placeholder}",
"detail": "low",
},
}
template = ChatPromptTemplate.from_messages(
[("human", [prompt])], template_format=template_format
)
assert template.format_messages(
image_type="image/png", image_data="base64data"
) == [
HumanMessage(
content=[
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64, base64data",
"detail": "low",
},
}
]
)
]
def test_chat_prompt_template_with_messages( def test_chat_prompt_template_with_messages(
messages: list[BaseMessagePromptTemplate], messages: list[BaseMessagePromptTemplate],
) -> None: ) -> None:

View File

@ -8,6 +8,7 @@ import pytest
from syrupy import SnapshotAssertion from syrupy import SnapshotAssertion
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from tests.unit_tests.pydantic_utils import _normalize_schema from tests.unit_tests.pydantic_utils import _normalize_schema
@ -610,7 +611,9 @@ async def test_prompt_ainvoke_with_metadata() -> None:
) )
@pytest.mark.parametrize("template_format", ["f-string", "mustache"]) @pytest.mark.parametrize("template_format", ["f-string", "mustache"])
def test_prompt_falsy_vars( def test_prompt_falsy_vars(
template_format: str, value: Any, expected: Union[str, dict[str, str]] template_format: PromptTemplateFormat,
value: Any,
expected: Union[str, dict[str, str]],
) -> None: ) -> None:
# each line is value, f-string, mustache # each line is value, f-string, mustache
if template_format == "f-string": if template_format == "f-string":