mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
core[minor]: Image prompt template (#14263)
Builds on Bagatur's (#13227). See unit test for example usage (below) ```python def test_chat_tmpl_from_messages_multipart_image() -> None: base64_image = "abcd123" other_base64_image = "abcd123" template = ChatPromptTemplate.from_messages( [ ("system", "You are an AI assistant named {name}."), ( "human", [ {"type": "text", "text": "What's in this image?"}, # OAI supports all these structures today { "type": "image_url", "image_url": "data:image/jpeg;base64,{my_image}", }, { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{my_image}"}, }, {"type": "image_url", "image_url": "{my_other_image}"}, { "type": "image_url", "image_url": {"url": "{my_other_image}", "detail": "medium"}, }, { "type": "image_url", "image_url": {"url": "https://www.langchain.com/image.png"}, }, { "type": "image_url", "image_url": {"url": ""}, }, ], ), ] ) messages = template.format_messages( name="R2D2", my_image=base64_image, my_other_image=other_base64_image ) expected = [ SystemMessage(content="You are an AI assistant named R2D2."), HumanMessage( content=[ {"type": "text", "text": "What's in this image?"}, { "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{other_base64_image}" }, }, { "type": "image_url", "image_url": {"url": f"{other_base64_image}"}, }, { "type": "image_url", "image_url": { "url": f"{other_base64_image}", "detail": "medium", }, }, { "type": "image_url", "image_url": {"url": "https://www.langchain.com/image.png"}, }, { "type": "image_url", "image_url": {"url": ""}, }, ] ), ] assert messages == expected ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Brace Sproul <braceasproul@gmail.com>
This commit is contained in:
parent
3c387bc12d
commit
38425c99d2
@ -3,6 +3,8 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Literal, Sequence
|
from typing import List, Literal, Sequence
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AnyMessage,
|
AnyMessage,
|
||||||
@ -82,6 +84,30 @@ class ChatPromptValue(PromptValue):
|
|||||||
return ["langchain", "prompts", "chat"]
|
return ["langchain", "prompts", "chat"]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageURL(TypedDict, total=False):
    """Dictionary describing one image; every key is optional (``total=False``)."""

    detail: Literal["auto", "low", "high"]
    """Detail level requested for the image."""

    url: str
    """Where the image comes from: a URL, or the base64 encoded image data."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePromptValue(PromptValue):
    """Prompt value that carries a single image reference."""

    image_url: ImageURL
    """The image held by this prompt value."""
    type: Literal["ImagePromptValue"] = "ImagePromptValue"

    def to_string(self) -> str:
        """Return the ``url`` entry of the image as the prompt's string form."""
        url = self.image_url["url"]
        return url

    def to_messages(self) -> List[BaseMessage]:
        """Return a one-element list holding a human message with the image."""
        human_message = HumanMessage(content=[self.image_url])
        return [human_message]
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptValueConcrete(ChatPromptValue):
|
class ChatPromptValueConcrete(ChatPromptValue):
|
||||||
"""Chat prompt value which explicitly lists out the message types it accepts.
|
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||||
For use in external schemas."""
|
For use in external schemas."""
|
||||||
|
@ -8,10 +8,12 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generic,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,7 +32,12 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
|
||||||
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
FormatOutputType = TypeVar("FormatOutputType")
|
||||||
|
|
||||||
|
|
||||||
|
class BasePromptTemplate(
|
||||||
|
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
|
||||||
|
):
|
||||||
"""Base class for all prompt templates, returning a prompt."""
|
"""Base class for all prompt templates, returning a prompt."""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
@ -142,7 +149,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
return {**partial_kwargs, **kwargs}
|
return {**partial_kwargs, **kwargs}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> FormatOutputType:
|
||||||
"""Format the prompt with the inputs.
|
"""Format the prompt with the inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -210,7 +217,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
|
|||||||
raise ValueError(f"{save_path} must be json or yaml")
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
|
||||||
|
|
||||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
|
||||||
"""Format a document into a string based on a prompt template.
|
"""Format a document into a string based on a prompt template.
|
||||||
|
|
||||||
First, this pulls information from the document from two sources:
|
First, this pulls information from the document from two sources:
|
||||||
@ -236,7 +243,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
|||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain_core import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
doc = Document(page_content="This is a joke", metadata={"page": "1"})
|
doc = Document(page_content="This is a joke", metadata={"page": "1"})
|
||||||
|
@ -13,8 +13,10 @@ from typing import (
|
|||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,10 +32,11 @@ from langchain_core.messages import (
|
|||||||
convert_to_messages,
|
convert_to_messages,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.base import get_msg_title_repr
|
from langchain_core.messages.base import get_msg_title_repr
|
||||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue
|
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
|
||||||
from langchain_core.prompts.base import BasePromptTemplate
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.prompts.image import ImagePromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import StringPromptTemplate
|
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
from langchain_core.utils import get_colored_text
|
from langchain_core.utils import get_colored_text
|
||||||
from langchain_core.utils.interactive_env import is_interactive_env
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
@ -288,14 +291,153 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
_StringImageMessagePromptTemplateT = TypeVar(
|
||||||
|
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _TextTemplateParam(TypedDict, total=False):
|
||||||
|
text: Union[str, Dict]
|
||||||
|
|
||||||
|
|
||||||
|
class _ImageTemplateParam(TypedDict, total=False):
|
||||||
|
image_url: Union[str, Dict]
|
||||||
|
|
||||||
|
|
||||||
|
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||||
"""Human message prompt template. This is a message sent from the user."""
|
"""Human message prompt template. This is a message sent from the user."""
|
||||||
|
|
||||||
|
prompt: Union[
|
||||||
|
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
|
||||||
|
]
|
||||||
|
"""Prompt template."""
|
||||||
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
|
"""Additional keyword arguments to pass to the prompt template."""
|
||||||
|
|
||||||
|
_msg_class: Type[BaseMessage]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "chat"]
|
return ["langchain", "prompts", "chat"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_template(
|
||||||
|
cls: Type[_StringImageMessagePromptTemplateT],
|
||||||
|
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
|
||||||
|
template_format: str = "f-string",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> _StringImageMessagePromptTemplateT:
|
||||||
|
"""Create a class from a string template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: a template.
|
||||||
|
template_format: format of the template.
|
||||||
|
**kwargs: keyword arguments to pass to the constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of this class.
|
||||||
|
"""
|
||||||
|
if isinstance(template, str):
|
||||||
|
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
|
||||||
|
template, template_format=template_format
|
||||||
|
)
|
||||||
|
return cls(prompt=prompt, **kwargs)
|
||||||
|
elif isinstance(template, list):
|
||||||
|
prompt = []
|
||||||
|
for tmpl in template:
|
||||||
|
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
|
||||||
|
if isinstance(tmpl, str):
|
||||||
|
text: str = tmpl
|
||||||
|
else:
|
||||||
|
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
|
||||||
|
prompt.append(
|
||||||
|
PromptTemplate.from_template(
|
||||||
|
text, template_format=template_format
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(tmpl, dict) and "image_url" in tmpl:
|
||||||
|
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
|
||||||
|
if isinstance(img_template, str):
|
||||||
|
vars = get_template_variables(img_template, "f-string")
|
||||||
|
if vars:
|
||||||
|
if len(vars) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Only one format variable allowed per image"
|
||||||
|
f" template.\nGot: {vars}"
|
||||||
|
f"\nFrom: {tmpl}"
|
||||||
|
)
|
||||||
|
input_variables = [vars[0]]
|
||||||
|
else:
|
||||||
|
input_variables = None
|
||||||
|
img_template = {"url": img_template}
|
||||||
|
img_template_obj = ImagePromptTemplate(
|
||||||
|
input_variables=input_variables, template=img_template
|
||||||
|
)
|
||||||
|
elif isinstance(img_template, dict):
|
||||||
|
img_template = dict(img_template)
|
||||||
|
if "url" in img_template:
|
||||||
|
input_variables = get_template_variables(
|
||||||
|
img_template["url"], "f-string"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_variables = None
|
||||||
|
img_template_obj = ImagePromptTemplate(
|
||||||
|
input_variables=input_variables, template=img_template
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
prompt.append(img_template_obj)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
return cls(prompt=prompt, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_template_file(
|
||||||
|
cls: Type[_StringImageMessagePromptTemplateT],
|
||||||
|
template_file: Union[str, Path],
|
||||||
|
input_variables: List[str],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> _StringImageMessagePromptTemplateT:
|
||||||
|
"""Create a class from a template file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_file: path to a template file. String or Path.
|
||||||
|
input_variables: list of input variables.
|
||||||
|
**kwargs: keyword arguments to pass to the constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of this class.
|
||||||
|
"""
|
||||||
|
with open(str(template_file), "r") as f:
|
||||||
|
template = f.read()
|
||||||
|
return cls.from_template(template, input_variables=input_variables, **kwargs)
|
||||||
|
|
||||||
|
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||||
|
"""Format messages from kwargs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keyword arguments to use for formatting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BaseMessages.
|
||||||
|
"""
|
||||||
|
return [self.format(**kwargs)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_variables(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Input variables for this prompt template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of input variable names.
|
||||||
|
"""
|
||||||
|
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
|
||||||
|
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
|
||||||
|
return input_variables
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> BaseMessage:
|
def format(self, **kwargs: Any) -> BaseMessage:
|
||||||
"""Format the prompt template.
|
"""Format the prompt template.
|
||||||
|
|
||||||
@ -305,53 +447,55 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted message.
|
Formatted message.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(self.prompt, StringPromptTemplate):
|
||||||
text = self.prompt.format(**kwargs)
|
text = self.prompt.format(**kwargs)
|
||||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
return self._msg_class(
|
||||||
|
content=text, additional_kwargs=self.additional_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
content = []
|
||||||
|
for prompt in self.prompt:
|
||||||
|
inputs = {var: kwargs[var] for var in prompt.input_variables}
|
||||||
|
if isinstance(prompt, StringPromptTemplate):
|
||||||
|
formatted: Union[str, ImageURL] = prompt.format(**inputs)
|
||||||
|
content.append({"type": "text", "text": formatted})
|
||||||
|
elif isinstance(prompt, ImagePromptTemplate):
|
||||||
|
formatted = prompt.format(**inputs)
|
||||||
|
content.append({"type": "image_url", "image_url": formatted})
|
||||||
|
return self._msg_class(
|
||||||
|
content=content, additional_kwargs=self.additional_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||||
|
"""Human message prompt template. This is a message sent from the user."""
|
||||||
|
|
||||||
|
_msg_class: Type[BaseMessage] = HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||||
"""AI message prompt template. This is a message sent from the AI."""
|
"""AI message prompt template. This is a message sent from the AI."""
|
||||||
|
|
||||||
|
_msg_class: Type[BaseMessage] = AIMessage
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "chat"]
|
return ["langchain", "prompts", "chat"]
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> BaseMessage:
|
|
||||||
"""Format the prompt template.
|
|
||||||
|
|
||||||
Args:
|
class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
|
||||||
**kwargs: Keyword arguments to use for formatting.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted message.
|
|
||||||
"""
|
|
||||||
text = self.prompt.format(**kwargs)
|
|
||||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|
||||||
"""System message prompt template.
|
"""System message prompt template.
|
||||||
This is a message that is not sent to the user.
|
This is a message that is not sent to the user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_msg_class: Type[BaseMessage] = SystemMessage
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "prompts", "chat"]
|
return ["langchain", "prompts", "chat"]
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> BaseMessage:
|
|
||||||
"""Format the prompt template.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: Keyword arguments to use for formatting.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted message.
|
|
||||||
"""
|
|
||||||
text = self.prompt.format(**kwargs)
|
|
||||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||||
"""Base class for chat prompt templates."""
|
"""Base class for chat prompt templates."""
|
||||||
@ -405,8 +549,7 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
|
|||||||
|
|
||||||
MessageLikeRepresentation = Union[
|
MessageLikeRepresentation = Union[
|
||||||
MessageLike,
|
MessageLike,
|
||||||
Tuple[str, str],
|
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
|
||||||
Tuple[Type, str],
|
|
||||||
str,
|
str,
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -738,7 +881,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
|
|
||||||
|
|
||||||
def _create_template_from_message_type(
|
def _create_template_from_message_type(
|
||||||
message_type: str, template: str
|
message_type: str, template: Union[str, list]
|
||||||
) -> BaseMessagePromptTemplate:
|
) -> BaseMessagePromptTemplate:
|
||||||
"""Create a message prompt template from a message type and template string.
|
"""Create a message prompt template from a message type and template string.
|
||||||
|
|
||||||
@ -754,9 +897,9 @@ def _create_template_from_message_type(
|
|||||||
template
|
template
|
||||||
)
|
)
|
||||||
elif message_type in ("ai", "assistant"):
|
elif message_type in ("ai", "assistant"):
|
||||||
message = AIMessagePromptTemplate.from_template(template)
|
message = AIMessagePromptTemplate.from_template(cast(str, template))
|
||||||
elif message_type == "system":
|
elif message_type == "system":
|
||||||
message = SystemMessagePromptTemplate.from_template(template)
|
message = SystemMessagePromptTemplate.from_template(cast(str, template))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unexpected message type: {message_type}. Use one of 'human',"
|
f"Unexpected message type: {message_type}. Use one of 'human',"
|
||||||
@ -799,7 +942,9 @@ def _convert_to_message(
|
|||||||
if isinstance(message_type_str, str):
|
if isinstance(message_type_str, str):
|
||||||
_message = _create_template_from_message_type(message_type_str, template)
|
_message = _create_template_from_message_type(message_type_str, template)
|
||||||
else:
|
else:
|
||||||
_message = message_type_str(prompt=PromptTemplate.from_template(template))
|
_message = message_type_str(
|
||||||
|
prompt=PromptTemplate.from_template(cast(str, template))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
||||||
|
|
||||||
|
76
libs/core/langchain_core/prompts/image.py
Normal file
76
libs/core/langchain_core/prompts/image.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
|
||||||
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
from langchain_core.utils import image as image_utils
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
    """An image prompt template for a multimodal model.

    ``template`` is a dict whose string values are ``str.format`` templates.
    After formatting, the ``url``, ``path`` and ``detail`` entries are used to
    build the resulting ``ImageURL``.
    """

    template: dict = Field(default_factory=dict)
    """Template for the prompt."""

    def __init__(self, **kwargs: Any) -> None:
        """Validate ``input_variables`` and initialise the template.

        Args:
            kwargs: Field values forwarded to the base class. A missing or
                ``None`` ``input_variables`` is treated as an empty list.

        Raises:
            ValueError: If ``input_variables`` contains ``url``, ``path`` or
                ``detail``, which are reserved keys of the template itself.
        """
        # Templates without placeholders may be built with an explicit
        # ``input_variables=None``; normalise that (and a missing key) to an
        # empty list so the overlap check below does not fail on ``set(None)``.
        if kwargs.get("input_variables") is None:
            kwargs["input_variables"] = []

        overlap = set(kwargs["input_variables"]) & {"url", "path", "detail"}
        if overlap:
            raise ValueError(
                "input_variables for the image template cannot contain"
                " any of 'url', 'path', or 'detail'."
                f" Found: {overlap}"
            )
        super().__init__(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "image-prompt"

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the template and wrap the result in an ``ImagePromptValue``."""
        return ImagePromptValue(image_url=self.format(**kwargs))

    def format(
        self,
        **kwargs: Any,
    ) -> ImageURL:
        """Format the prompt with the inputs.

        Args:
            kwargs: Values substituted into the string entries of ``template``.
                ``url``, ``path`` and ``detail`` given here take precedence
                over the formatted template entries of the same name.

        Returns:
            An ``ImageURL`` dict with a ``url`` key and, when a detail level
            is available, a ``detail`` key.

        Raises:
            ValueError: If neither a url nor a path is available, or if the
                one used is not a string.

        Example:

            .. code-block:: python

                prompt.format(variable1="foo")
        """
        formatted = {}
        for k, v in self.template.items():
            # Only string values are templates; anything else is passed through.
            if isinstance(v, str):
                formatted[k] = v.format(**kwargs)
            else:
                formatted[k] = v
        url = kwargs.get("url") or formatted.get("url")
        path = kwargs.get("path") or formatted.get("path")
        detail = kwargs.get("detail") or formatted.get("detail")
        if not url and not path:
            raise ValueError("Must provide either url or path.")
        if not url:
            if not isinstance(path, str):
                raise ValueError("path must be a string.")
            # No url given: build a data URL from the file at ``path``.
            url = image_utils.image_to_data_url(path)
        if not isinstance(url, str):
            raise ValueError("url must be a string.")
        output: ImageURL = {"url": url}
        if detail:
            # Don't check literal values here: let the API check them
            output["detail"] = detail  # type: ignore[typeddict-item]
        return output
|
@ -4,6 +4,7 @@
|
|||||||
These functions do not depend on any other LangChain module.
|
These functions do not depend on any other LangChain module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from langchain_core.utils import image
|
||||||
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
|
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
|
||||||
from langchain_core.utils.formatting import StrictFormatter, formatter
|
from langchain_core.utils.formatting import StrictFormatter, formatter
|
||||||
from langchain_core.utils.input import (
|
from langchain_core.utils.input import (
|
||||||
@ -41,6 +42,7 @@ __all__ = [
|
|||||||
"xor_args",
|
"xor_args",
|
||||||
"try_load_from_hub",
|
"try_load_from_hub",
|
||||||
"build_extra_kwargs",
|
"build_extra_kwargs",
|
||||||
|
"image",
|
||||||
"get_from_env",
|
"get_from_env",
|
||||||
"get_from_dict_or_env",
|
"get_from_dict_or_env",
|
||||||
"stringify_dict",
|
"stringify_dict",
|
||||||
|
14
libs/core/langchain_core/utils/image.py
Normal file
14
libs/core/langchain_core/utils/image.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image(image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as base64 text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_data_url(image_path: str) -> str:
    """Build a base64 ``data:`` URL from the file at ``image_path``.

    Args:
        image_path: Path of the file to embed.

    Returns:
        A string of the form ``data:<mime type>;base64,<encoded bytes>``.
    """
    encoding = encode_image(image_path)
    # ``guess_type`` yields ``None`` when the extension is unknown; fall back
    # to a generic type rather than emitting the malformed "data:None;...".
    mime_type = mimetypes.guess_type(image_path)[0] or "application/octet-stream"
    return f"data:{mime_type};base64,{encoding}"
|
@ -3,6 +3,9 @@ from typing import Any, List, Union
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from langchain_core._api.deprecation import (
|
||||||
|
LangChainPendingDeprecationWarning,
|
||||||
|
)
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -243,6 +246,7 @@ def test_chat_valid_infer_variables() -> None:
|
|||||||
|
|
||||||
def test_chat_from_role_strings() -> None:
|
def test_chat_from_role_strings() -> None:
|
||||||
"""Test instantiation of chat template from role strings."""
|
"""Test instantiation of chat template from role strings."""
|
||||||
|
with pytest.warns(LangChainPendingDeprecationWarning):
|
||||||
template = ChatPromptTemplate.from_role_strings(
|
template = ChatPromptTemplate.from_role_strings(
|
||||||
[
|
[
|
||||||
("system", "You are a bot."),
|
("system", "You are a bot."),
|
||||||
@ -363,6 +367,136 @@ def test_chat_message_partial() -> None:
|
|||||||
assert template2.format(input="hello") == get_buffer_string(expected)
|
assert template2.format(input="hello") == get_buffer_string(expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_tmpl_from_messages_multipart_text() -> None:
    """Multi-part human content made of literal text parts formats unchanged."""
    human_parts = [
        {"type": "text", "text": "What's in this image?"},
        {"type": "text", "text": "Oh nvm"},
    ]
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant named {name}."),
            ("human", human_parts),
        ]
    )
    system_message = SystemMessage(content="You are an AI assistant named R2D2.")
    human_message = HumanMessage(
        content=[
            {"type": "text", "text": "What's in this image?"},
            {"type": "text", "text": "Oh nvm"},
        ]
    )
    assert prompt.format_messages(name="R2D2") == [system_message, human_message]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
    """Placeholders inside a text part of multi-part content are filled in."""
    human_parts = [
        {"type": "text", "text": "What's in this {object_name}?"},
        {"type": "text", "text": "Oh nvm"},
    ]
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant named {name}."),
            ("human", human_parts),
        ]
    )
    system_message = SystemMessage(content="You are an AI assistant named R2D2.")
    human_message = HumanMessage(
        content=[
            {"type": "text", "text": "What's in this image?"},
            {"type": "text", "text": "Oh nvm"},
        ]
    )
    formatted = prompt.format_messages(name="R2D2", object_name="image")
    assert formatted == [system_message, human_message]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_tmpl_from_messages_multipart_image() -> None:
    """Each supported ``image_url`` template shape formats to a dict part."""
    # Distinct payloads so that mixing up the two variables is detected.
    base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
    other_base64_image = "R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant named {name}."),
            (
                "human",
                [
                    {"type": "text", "text": "What's in this image?"},
                    # Bare string template.
                    {
                        "type": "image_url",
                        "image_url": "data:image/jpeg;base64,{my_image}",
                    },
                    # Dict template with a templated url.
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,{my_image}"},
                    },
                    {"type": "image_url", "image_url": "{my_other_image}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "{my_other_image}", "detail": "medium"},
                    },
                    # Literal urls without any placeholder.
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.langchain.com/image.png"},
                    },
                    # Must be non-empty: formatting an image template raises
                    # ValueError when neither a url nor a path is provided.
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,foobar"},
                    },
                ],
            ),
        ]
    )
    messages = template.format_messages(
        name="R2D2", my_image=base64_image, my_other_image=other_base64_image
    )
    expected = [
        SystemMessage(content="You are an AI assistant named R2D2."),
        HumanMessage(
            content=[
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"{other_base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{other_base64_image}",
                        "detail": "medium",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": "https://www.langchain.com/image.png"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,foobar"},
                },
            ]
        ),
    ]
    assert messages == expected
|
||||||
|
|
||||||
|
|
||||||
def test_messages_placeholder() -> None:
|
def test_messages_placeholder() -> None:
|
||||||
prompt = MessagesPlaceholder("history")
|
prompt = MessagesPlaceholder("history")
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
|
@ -16,6 +16,7 @@ EXPECTED_ALL = [
|
|||||||
"xor_args",
|
"xor_args",
|
||||||
"try_load_from_hub",
|
"try_load_from_hub",
|
||||||
"build_extra_kwargs",
|
"build_extra_kwargs",
|
||||||
|
"image",
|
||||||
"get_from_dict_or_env",
|
"get_from_dict_or_env",
|
||||||
"get_from_env",
|
"get_from_env",
|
||||||
"stringify_dict",
|
"stringify_dict",
|
||||||
|
Loading…
Reference in New Issue
Block a user