core: fix Image prompt template hardcoded template format (#27495)

Fixes #27411 

**Description:** Adds `template_format` to the `ImagePromptTemplate`
class and updates passing in the `template_format` parameter from
ChatPromptTemplate instead of the hardcoded "f-string".
Also updated docs and typing related to `template_format` to be more
up-to-date and specific.

**Dependencies:** None

**Add tests and docs**: Added unit tests to validate fix. Needed to
update `test_chat` snapshot due to adding new attribute
`template_format` in `ImagePromptTemplate`.

---------

Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
Chun Kang Lu 2024-10-21 17:31:40 -04:00 committed by GitHub
parent 403c0ea801
commit 380449a7a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 131 additions and 25 deletions

View File

@ -1 +1,2 @@
jinja2>=3,<4
mustache>=0.1.4,<1

View File

@ -8,7 +8,6 @@ from pathlib import Path
from typing import (
Annotated,
Any,
Literal,
Optional,
TypedDict,
TypeVar,
@ -40,7 +39,11 @@ from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.prompts.string import (
PromptTemplateFormat,
StringPromptTemplate,
get_template_variables,
)
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@ -296,7 +299,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
def from_template(
cls: type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> MessagePromptTemplateT:
@ -486,7 +489,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
def from_template(
cls: type[_StringImageMessagePromptTemplateT],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
@ -495,7 +498,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Args:
template: a template.
template_format: format of the template. Defaults to "f-string".
template_format: format of the template.
Options are: 'f-string', 'mustache', 'jinja2'. Defaults to "f-string".
partial_variables: A dictionary of variables that can be used too partially.
Defaults to None.
**kwargs: keyword arguments to pass to the constructor.
@ -533,7 +537,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
input_variables = []
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
vars = get_template_variables(img_template, template_format)
if vars:
if len(vars) > 1:
msg = (
@ -545,7 +549,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
input_variables = [vars[0]]
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
input_variables=input_variables,
template=img_template,
template_format=template_format,
)
elif isinstance(img_template, dict):
img_template = dict(img_template)
@ -553,11 +559,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
if key in img_template:
input_variables.extend(
get_template_variables(
img_template[key], "f-string"
img_template[key], template_format
)
)
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
input_variables=input_variables,
template=img_template,
template_format=template_format,
)
else:
msg = f"Invalid image template: {tmpl}"
@ -943,7 +951,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
self,
messages: Sequence[MessageLikeRepresentation],
*,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
**kwargs: Any,
) -> None:
"""Create a chat prompt template from a variety of message formats.
@ -1160,7 +1168,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@ -1354,7 +1362,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type(
message_type: str,
template: Union[str, list],
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
@ -1426,7 +1434,7 @@ def _create_template_from_message_type(
def _convert_to_message(
message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.

View File

@ -9,6 +9,7 @@ from typing_extensions import Self
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
StringPromptTemplate,
)
@ -36,8 +37,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
prefix: Optional[StringPromptTemplate] = None
"""A PromptTemplate to put before the examples."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'jinja2', 'mustache'."""
validate_template: bool = False
"""Whether or not to try validating the template."""

View File

@ -4,6 +4,10 @@ from pydantic import Field
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
)
from langchain_core.runnables import run_in_executor
from langchain_core.utils import image as image_utils
@ -13,6 +17,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
template: dict = Field(default_factory=dict)
"""Template for the prompt."""
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""
def __init__(self, **kwargs: Any) -> None:
if "input_variables" not in kwargs:
@ -85,7 +92,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
formatted = {}
for k, v in self.template.items():
if isinstance(v, str):
formatted[k] = v.format(**kwargs)
formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format](
v, **kwargs
)
else:
formatted[k] = v
url = kwargs.get("url") or formatted.get("url")

View File

@ -4,12 +4,13 @@ from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, Union
from pydantic import BaseModel, model_validator
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
PromptTemplateFormat,
StringPromptTemplate,
check_valid_template,
get_template_variables,
@ -24,7 +25,8 @@ class PromptTemplate(StringPromptTemplate):
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax.
The template can be formatted using either f-strings (default), jinja2,
or mustache syntax.
*Security warning*:
Prefer using `template_format="f-string"` instead of
@ -67,7 +69,7 @@ class PromptTemplate(StringPromptTemplate):
template: str
"""The prompt template."""
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
template_format: PromptTemplateFormat = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""
@ -248,7 +250,7 @@ class PromptTemplate(StringPromptTemplate):
cls,
template: str,
*,
template_format: str = "f-string",
template_format: PromptTemplateFormat = "f-string",
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
@ -270,7 +272,7 @@ class PromptTemplate(StringPromptTemplate):
Args:
template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings.
`mustache` for mustache, and `f-string` for f-strings.
Defaults to `f-string`.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable
from typing import Any, Callable, Literal
from pydantic import BaseModel, create_model
@ -16,6 +16,8 @@ from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
"""Format a template using jinja2.

View File

@ -2,7 +2,6 @@ from collections.abc import Iterator, Mapping, Sequence
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
)
@ -15,6 +14,7 @@ from langchain_core.prompts.chat import (
ChatPromptTemplate,
MessageLikeRepresentation,
)
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.runnables.base import (
Other,
Runnable,
@ -38,7 +38,7 @@ class StructuredPrompt(ChatPromptTemplate):
schema_: Optional[Union[dict, type[BaseModel]]] = None,
*,
structured_output_kwargs: Optional[dict[str, Any]] = None,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
template_format: PromptTemplateFormat = "f-string",
**kwargs: Any,
) -> None:
schema_ = schema_ or kwargs.pop("schema")

View File

@ -3119,6 +3119,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,{my_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
@ -3138,6 +3139,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,{my_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
@ -3157,6 +3159,7 @@
'template': dict({
'url': '{my_other_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
@ -3177,6 +3180,7 @@
'detail': 'medium',
'url': '{my_other_image}',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
@ -3195,6 +3199,7 @@
'template': dict({
'url': 'https://www.langchain.com/image.png',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
@ -3213,6 +3218,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,foobar',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',
@ -3231,6 +3237,7 @@
'template': dict({
'url': 'data:image/jpeg;base64,foobar',
}),
'template_format': 'f-string',
}),
'lc': 1,
'name': 'ImagePromptTemplate',

View File

@ -31,6 +31,7 @@ from langchain_core.prompts.chat import (
SystemMessagePromptTemplate,
_convert_to_message,
)
from langchain_core.prompts.string import PromptTemplateFormat
from tests.unit_tests.pydantic_utils import _normalize_schema
@ -298,6 +299,77 @@ def test_chat_prompt_template_from_messages_mustache() -> None:
]
@pytest.mark.requires("jinja2")
def test_chat_prompt_template_from_messages_jinja2() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful AI bot. Your name is {{ name }}."),
("human", "Hello, how are you doing?"),
("ai", "I'm doing well, thanks!"),
("human", "{{ user_input }}"),
],
"jinja2",
)
messages = template.format_messages(name="Bob", user_input="What is your name?")
assert messages == [
SystemMessage(
content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
),
HumanMessage(
content="Hello, how are you doing?", additional_kwargs={}, example=False
),
AIMessage(
content="I'm doing well, thanks!", additional_kwargs={}, example=False
),
HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
]
@pytest.mark.requires("jinja2")
@pytest.mark.requires("mustache")
@pytest.mark.parametrize(
"template_format,image_type_placeholder,image_data_placeholder",
[
("f-string", "{image_type}", "{image_data}"),
("mustache", "{{image_type}}", "{{image_data}}"),
("jinja2", "{{ image_type }}", "{{ image_data }}"),
],
)
def test_chat_prompt_template_image_prompt_from_message(
template_format: PromptTemplateFormat,
image_type_placeholder: str,
image_data_placeholder: str,
) -> None:
prompt = {
"type": "image_url",
"image_url": {
"url": f"data:{image_type_placeholder};base64, {image_data_placeholder}",
"detail": "low",
},
}
template = ChatPromptTemplate.from_messages(
[("human", [prompt])], template_format=template_format
)
assert template.format_messages(
image_type="image/png", image_data="base64data"
) == [
HumanMessage(
content=[
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64, base64data",
"detail": "low",
},
}
]
)
]
def test_chat_prompt_template_with_messages(
messages: list[BaseMessagePromptTemplate],
) -> None:

View File

@ -8,6 +8,7 @@ import pytest
from syrupy import SnapshotAssertion
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import PromptTemplateFormat
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from tests.unit_tests.pydantic_utils import _normalize_schema
@ -610,7 +611,9 @@ async def test_prompt_ainvoke_with_metadata() -> None:
)
@pytest.mark.parametrize("template_format", ["f-string", "mustache"])
def test_prompt_falsy_vars(
template_format: str, value: Any, expected: Union[str, dict[str, str]]
template_format: PromptTemplateFormat,
value: Any,
expected: Union[str, dict[str, str]],
) -> None:
# each line is value, f-string, mustache
if template_format == "f-string":