mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 09:40:26 +00:00
core: fix Image prompt template hardcoded template format (#27495)
Fixes #27411 **Description:** Adds `template_format` to the `ImagePromptTemplate` class and updates passing in the `template_format` parameter from ChatPromptTemplate instead of the hardcoded "f-string". Also updated docs and typing related to `template_format` to be more up-to-date and specific. **Dependencies:** None **Add tests and docs**: Added unit tests to validate fix. Needed to update `test_chat` snapshot due to adding new attribute `template_format` in `ImagePromptTemplate`. --------- Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
parent
403c0ea801
commit
380449a7a9
@ -1 +1,2 @@
|
||||
jinja2>=3,<4
|
||||
mustache>=0.1.4,<1
|
||||
|
@ -8,7 +8,6 @@ from pathlib import Path
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
@ -40,7 +39,11 @@ from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.image import ImagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||
from langchain_core.prompts.string import (
|
||||
PromptTemplateFormat,
|
||||
StringPromptTemplate,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@ -296,7 +299,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
def from_template(
|
||||
cls: type[MessagePromptTemplateT],
|
||||
template: str,
|
||||
template_format: str = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
@ -486,7 +489,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
def from_template(
|
||||
cls: type[_StringImageMessagePromptTemplateT],
|
||||
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||
template_format: str = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
*,
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
@ -495,7 +498,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
|
||||
Args:
|
||||
template: a template.
|
||||
template_format: format of the template. Defaults to "f-string".
|
||||
template_format: format of the template.
|
||||
Options are: 'f-string', 'mustache', 'jinja2'. Defaults to "f-string".
|
||||
partial_variables: A dictionary of variables that can be used too partially.
|
||||
Defaults to None.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
@ -533,7 +537,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
|
||||
input_variables = []
|
||||
if isinstance(img_template, str):
|
||||
vars = get_template_variables(img_template, "f-string")
|
||||
vars = get_template_variables(img_template, template_format)
|
||||
if vars:
|
||||
if len(vars) > 1:
|
||||
msg = (
|
||||
@ -545,7 +549,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
input_variables = [vars[0]]
|
||||
img_template = {"url": img_template}
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables, template=img_template
|
||||
input_variables=input_variables,
|
||||
template=img_template,
|
||||
template_format=template_format,
|
||||
)
|
||||
elif isinstance(img_template, dict):
|
||||
img_template = dict(img_template)
|
||||
@ -553,11 +559,13 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
if key in img_template:
|
||||
input_variables.extend(
|
||||
get_template_variables(
|
||||
img_template[key], "f-string"
|
||||
img_template[key], template_format
|
||||
)
|
||||
)
|
||||
img_template_obj = ImagePromptTemplate(
|
||||
input_variables=input_variables, template=img_template
|
||||
input_variables=input_variables,
|
||||
template=img_template,
|
||||
template_format=template_format,
|
||||
)
|
||||
else:
|
||||
msg = f"Invalid image template: {tmpl}"
|
||||
@ -943,7 +951,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
self,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
*,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
@ -1160,7 +1168,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
def from_messages(
|
||||
cls,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
|
||||
@ -1354,7 +1362,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
def _create_template_from_message_type(
|
||||
message_type: str,
|
||||
template: Union[str, list],
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
) -> BaseMessagePromptTemplate:
|
||||
"""Create a message prompt template from a message type and template string.
|
||||
|
||||
@ -1426,7 +1434,7 @@ def _create_template_from_message_type(
|
||||
|
||||
def _convert_to_message(
|
||||
message: MessageLikeRepresentation,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
|
@ -9,6 +9,7 @@ from typing_extensions import Self
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
PromptTemplateFormat,
|
||||
StringPromptTemplate,
|
||||
)
|
||||
|
||||
@ -36,8 +37,9 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
prefix: Optional[StringPromptTemplate] = None
|
||||
"""A PromptTemplate to put before the examples."""
|
||||
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
template_format: PromptTemplateFormat = "f-string"
|
||||
"""The format of the prompt template.
|
||||
Options are: 'f-string', 'jinja2', 'mustache'."""
|
||||
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
@ -4,6 +4,10 @@ from pydantic import Field
|
||||
|
||||
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
PromptTemplateFormat,
|
||||
)
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.utils import image as image_utils
|
||||
|
||||
@ -13,6 +17,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
|
||||
template: dict = Field(default_factory=dict)
|
||||
"""Template for the prompt."""
|
||||
template_format: PromptTemplateFormat = "f-string"
|
||||
"""The format of the prompt template.
|
||||
Options are: 'f-string', 'mustache', 'jinja2'."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
if "input_variables" not in kwargs:
|
||||
@ -85,7 +92,9 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
formatted = {}
|
||||
for k, v in self.template.items():
|
||||
if isinstance(v, str):
|
||||
formatted[k] = v.format(**kwargs)
|
||||
formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format](
|
||||
v, **kwargs
|
||||
)
|
||||
else:
|
||||
formatted[k] = v
|
||||
url = kwargs.get("url") or formatted.get("url")
|
||||
|
@ -4,12 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
PromptTemplateFormat,
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
@ -24,7 +25,8 @@ class PromptTemplate(StringPromptTemplate):
|
||||
A prompt template consists of a string template. It accepts a set of parameters
|
||||
from the user that can be used to generate a prompt for a language model.
|
||||
|
||||
The template can be formatted using either f-strings (default) or jinja2 syntax.
|
||||
The template can be formatted using either f-strings (default), jinja2,
|
||||
or mustache syntax.
|
||||
|
||||
*Security warning*:
|
||||
Prefer using `template_format="f-string"` instead of
|
||||
@ -67,7 +69,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
|
||||
template_format: PromptTemplateFormat = "f-string"
|
||||
"""The format of the prompt template.
|
||||
Options are: 'f-string', 'mustache', 'jinja2'."""
|
||||
|
||||
@ -248,7 +250,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
cls,
|
||||
template: str,
|
||||
*,
|
||||
template_format: str = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
partial_variables: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> PromptTemplate:
|
||||
@ -270,7 +272,7 @@ class PromptTemplate(StringPromptTemplate):
|
||||
Args:
|
||||
template: The template to load.
|
||||
template_format: The format of the template. Use `jinja2` for jinja2,
|
||||
and `f-string` or None for f-strings.
|
||||
`mustache` for mustache, and `f-string` for f-strings.
|
||||
Defaults to `f-string`.
|
||||
partial_variables: A dictionary of variables that can be used to partially
|
||||
fill in the template. For example, if the template is
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
@ -16,6 +16,8 @@ from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
PromptTemplateFormat = Literal["f-string", "mustache", "jinja2"]
|
||||
|
||||
|
||||
def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
|
||||
"""Format a template using jinja2.
|
||||
|
@ -2,7 +2,6 @@ from collections.abc import Iterator, Mapping, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
@ -15,6 +14,7 @@ from langchain_core.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
MessageLikeRepresentation,
|
||||
)
|
||||
from langchain_core.prompts.string import PromptTemplateFormat
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
@ -38,7 +38,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
schema_: Optional[Union[dict, type[BaseModel]]] = None,
|
||||
*,
|
||||
structured_output_kwargs: Optional[dict[str, Any]] = None,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
template_format: PromptTemplateFormat = "f-string",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
schema_ = schema_ or kwargs.pop("schema")
|
||||
|
@ -3119,6 +3119,7 @@
|
||||
'template': dict({
|
||||
'url': 'data:image/jpeg;base64,{my_image}',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
@ -3138,6 +3139,7 @@
|
||||
'template': dict({
|
||||
'url': 'data:image/jpeg;base64,{my_image}',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
@ -3157,6 +3159,7 @@
|
||||
'template': dict({
|
||||
'url': '{my_other_image}',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
@ -3177,6 +3180,7 @@
|
||||
'detail': 'medium',
|
||||
'url': '{my_other_image}',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
@ -3195,6 +3199,7 @@
|
||||
'template': dict({
|
||||
'url': 'https://www.langchain.com/image.png',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
@ -3213,6 +3218,7 @@
|
||||
'template': dict({
|
||||
'url': 'data:image/jpeg;base64,foobar',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
@ -3231,6 +3237,7 @@
|
||||
'template': dict({
|
||||
'url': 'data:image/jpeg;base64,foobar',
|
||||
}),
|
||||
'template_format': 'f-string',
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ImagePromptTemplate',
|
||||
|
@ -31,6 +31,7 @@ from langchain_core.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
_convert_to_message,
|
||||
)
|
||||
from langchain_core.prompts.string import PromptTemplateFormat
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||
|
||||
|
||||
@ -298,6 +299,77 @@ def test_chat_prompt_template_from_messages_mustache() -> None:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
def test_chat_prompt_template_from_messages_jinja2() -> None:
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a helpful AI bot. Your name is {{ name }}."),
|
||||
("human", "Hello, how are you doing?"),
|
||||
("ai", "I'm doing well, thanks!"),
|
||||
("human", "{{ user_input }}"),
|
||||
],
|
||||
"jinja2",
|
||||
)
|
||||
|
||||
messages = template.format_messages(name="Bob", user_input="What is your name?")
|
||||
|
||||
assert messages == [
|
||||
SystemMessage(
|
||||
content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
|
||||
),
|
||||
HumanMessage(
|
||||
content="Hello, how are you doing?", additional_kwargs={}, example=False
|
||||
),
|
||||
AIMessage(
|
||||
content="I'm doing well, thanks!", additional_kwargs={}, example=False
|
||||
),
|
||||
HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
@pytest.mark.requires("mustache")
|
||||
@pytest.mark.parametrize(
|
||||
"template_format,image_type_placeholder,image_data_placeholder",
|
||||
[
|
||||
("f-string", "{image_type}", "{image_data}"),
|
||||
("mustache", "{{image_type}}", "{{image_data}}"),
|
||||
("jinja2", "{{ image_type }}", "{{ image_data }}"),
|
||||
],
|
||||
)
|
||||
def test_chat_prompt_template_image_prompt_from_message(
|
||||
template_format: PromptTemplateFormat,
|
||||
image_type_placeholder: str,
|
||||
image_data_placeholder: str,
|
||||
) -> None:
|
||||
prompt = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{image_type_placeholder};base64, {image_data_placeholder}",
|
||||
"detail": "low",
|
||||
},
|
||||
}
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[("human", [prompt])], template_format=template_format
|
||||
)
|
||||
assert template.format_messages(
|
||||
image_type="image/png", image_data="base64data"
|
||||
) == [
|
||||
HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64, base64data",
|
||||
"detail": "low",
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_chat_prompt_template_with_messages(
|
||||
messages: list[BaseMessagePromptTemplate],
|
||||
) -> None:
|
||||
|
@ -8,6 +8,7 @@ import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import PromptTemplateFormat
|
||||
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
|
||||
from tests.unit_tests.pydantic_utils import _normalize_schema
|
||||
|
||||
@ -610,7 +611,9 @@ async def test_prompt_ainvoke_with_metadata() -> None:
|
||||
)
|
||||
@pytest.mark.parametrize("template_format", ["f-string", "mustache"])
|
||||
def test_prompt_falsy_vars(
|
||||
template_format: str, value: Any, expected: Union[str, dict[str, str]]
|
||||
template_format: PromptTemplateFormat,
|
||||
value: Any,
|
||||
expected: Union[str, dict[str, str]],
|
||||
) -> None:
|
||||
# each line is value, f-string, mustache
|
||||
if template_format == "f-string":
|
||||
|
Loading…
Reference in New Issue
Block a user