Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-09 02:33:34 +00:00)

Compare commits: 8 commits, cc/model_k...cc/multi_m
| Author | SHA1 | Date |
|---|---|---|
| | f3e23022ef | |
| | 7e3caec720 | |
| | 35fbe24532 | |
| | 0354dec091 | |
| | 4c23ceb9ef | |
| | b1fc20cbcd | |
| | cbd05c66de | |
| | 99646c143d | |
@@ -102,6 +102,12 @@ def test_serializable_mapping() -> None:
        "modifier",
        "RemoveMessage",
    ),
    ("langchain", "prompts", "data", "DataPromptTemplate"): (
        "langchain_core",
        "prompts",
        "data",
        "DataPromptTemplate",
    ),
    ("langchain", "chat_models", "mistralai", "ChatMistralAI"): (
        "langchain_mistralai",
        "chat_models",
@@ -52,7 +52,9 @@ from langchain_core.messages import (
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    convert_image_content_block_to_image_url,
    convert_to_messages,
    is_data_content_block,
    message_chunk_to_message,
)
from langchain_core.outputs import (
@@ -103,6 +105,36 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
    return generations


def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Format messages for tracing in on_chat_model_start.

    For backward compatibility, we update image content blocks to OpenAI Chat
    Completions format.

    Args:
        messages: List of messages to format.

    Returns:
        List of messages formatted for tracing.
    """
    messages_to_trace = []
    for message in messages:
        message_to_trace = message
        if isinstance(message.content, list):
            for idx, block in enumerate(message.content):
                if (
                    isinstance(block, dict)
                    and is_data_content_block(block)
                    and block.get("type") == "image"
                ):
                    message_to_trace = message.model_copy(deep=True)
                    message_to_trace.content[idx] = (  # type: ignore[index]  # mypy confused by .model_copy
                        convert_image_content_block_to_image_url(block)  # type: ignore[arg-type]
                    )
        messages_to_trace.append(message_to_trace)
    return messages_to_trace


def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    """Generate from a stream.

@@ -439,7 +471,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
        )
        (run_manager,) = callback_manager.on_chat_model_start(
            self._serialized,
            [messages],
            [_format_for_tracing(messages)],
            invocation_params=params,
            options=options,
            name=config.get("run_name"),
@@ -524,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
        )
        (run_manager,) = await callback_manager.on_chat_model_start(
            self._serialized,
            [messages],
            [_format_for_tracing(messages)],
            invocation_params=params,
            options=options,
            name=config.get("run_name"),
@@ -703,9 +735,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
            inheritable_metadata,
            self.metadata,
        )
        messages_to_trace = [
            _format_for_tracing(message_list) for message_list in messages
        ]
        run_managers = callback_manager.on_chat_model_start(
            self._serialized,
            messages,
            messages_to_trace,
            invocation_params=params,
            options=options,
            name=run_name,
@@ -812,9 +847,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
            self.metadata,
        )

        messages_to_trace = [
            _format_for_tracing(message_list) for message_list in messages
        ]
        run_managers = await callback_manager.on_chat_model_start(
            self._serialized,
            messages,
            messages_to_trace,
            invocation_params=params,
            options=options,
            name=run_name,
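As a point of reference, here is a minimal sketch (not part of the diff) of what the new tracing path does with a standard image block, assuming the helpers imported in the hunk above:

```python
from langchain_core.messages import (
    HumanMessage,
    convert_image_content_block_to_image_url,
)

# Hypothetical standard image block in the new cross-provider format.
image_block = {
    "type": "image",
    "source_type": "base64",
    "mime_type": "image/png",
    "source": "<base64 data>",
}
message = HumanMessage(
    content=[{"type": "text", "text": "describe this image"}, image_block]
)

# _format_for_tracing leaves `message` untouched and traces a deep copy in which
# the image block is rewritten to the OpenAI Chat Completions shape:
print(convert_image_content_block_to_image_url(image_block))
# {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,<base64 data>'}}
```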
@@ -146,6 +146,12 @@ SERIALIZABLE_MAPPING: dict[tuple[str, ...], tuple[str, ...]] = {
        "image",
        "ImagePromptTemplate",
    ),
    ("langchain", "prompts", "data", "DataPromptTemplate"): (
        "langchain_core",
        "prompts",
        "data",
        "DataPromptTemplate",
    ),
    ("langchain", "schema", "agent", "AgentActionMessageLog"): (
        "langchain_core",
        "agents",
@@ -27,6 +27,11 @@ from langchain_core.messages.base import (
    messages_to_dict,
)
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.content_blocks import (
    DataContentBlock,
    convert_image_content_block_to_image_url,
    is_data_content_block,
)
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.modifier import RemoveMessage
@@ -60,6 +65,7 @@ __all__ = [
    "BaseMessageChunk",
    "ChatMessage",
    "ChatMessageChunk",
    "DataContentBlock",
    "FunctionMessage",
    "FunctionMessageChunk",
    "HumanMessage",
@@ -74,8 +80,10 @@ __all__ = [
    "ToolMessageChunk",
    "RemoveMessage",
    "_message_from_dict",
    "convert_image_content_block_to_image_url",
    "convert_to_messages",
    "get_buffer_string",
    "is_data_content_block",
    "merge_content",
    "message_chunk_to_message",
    "message_to_dict",
libs/core/langchain_core/messages/content_blocks.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""Types for content blocks."""

from typing import Literal

from typing_extensions import Required, TypedDict


class DataContentBlock(TypedDict, total=False):
    """Data content block."""

    type: Required[Literal["image", "audio", "file"]]
    """Type of the content block."""
    source_type: Required[Literal["url", "base64", "id", "text"]]
    """Source type."""
    source: Required[str]
    """Data as a URL or data-URI, identifier, or plain-text."""
    mime_type: str
    """MIME type of the content block (if block represents base64 data.)"""
    metadata: dict
    """Provider-specific metadata such as citations or filenames."""


def is_data_content_block(
    content_block: dict,
) -> bool:
    """Check if the content block is a data content block.

    Args:
        content_block: The content block to check.

    Returns:
        True if the content block is a data content block, False otherwise.
    """
    required_keys = DataContentBlock.__required_keys__
    return all(required_key in content_block for required_key in required_keys)


def convert_image_content_block_to_image_url(content_block: DataContentBlock) -> dict:
    """Convert image content block to format expected by OpenAI Chat Completions API."""
    if content_block["source_type"] == "url":
        return {
            "type": "image_url",
            "image_url": {
                "url": content_block["source"],
            },
        }
    if content_block["source_type"] == "base64":
        if "mime_type" not in content_block:
            error_message = "mime_type key is required for base64 data."
            raise ValueError(error_message)
        mime_type = content_block["mime_type"]
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{content_block['source']}",
            },
        }
    error_message = "Unsupported source type. Only 'url' and 'base64' are supported."
    raise ValueError(error_message)
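A brief usage sketch of the new helpers (illustrative only; the URL is hypothetical):

```python
from langchain_core.messages import DataContentBlock, is_data_content_block

block: DataContentBlock = {
    "type": "file",
    "source_type": "url",
    "source": "https://example.com/report.pdf",
}

# is_data_content_block only checks that the required keys
# ("type", "source_type", "source") are all present.
assert is_data_content_block(block)
assert not is_data_content_block({"type": "text", "text": "hello"})
```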
@@ -16,6 +16,7 @@ from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
    AnyMessage,
    BaseMessage,
    DataContentBlock,
    HumanMessage,
    get_buffer_string,
)
@@ -130,6 +131,22 @@ class ImagePromptValue(PromptValue):
        return [HumanMessage(content=[cast("dict", self.image_url)])]


class DataPromptValue(PromptValue):
    """Prompt value for multi-modal data."""

    content_block: DataContentBlock
    """Multi-modal content block."""
    type: Literal["DataPromptValue"] = "DataPromptValue"

    def to_string(self) -> str:
        """Return source data as a string."""
        return self.content_block["source"]

    def to_messages(self) -> list[BaseMessage]:
        """Return prompt (image URL) as messages."""
        return [HumanMessage(content=[cast("dict", self.content_block)])]


class ChatPromptValueConcrete(ChatPromptValue):
    """Chat prompt value which explicitly lists out the message types it accepts.
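A small sketch of how the new DataPromptValue behaves (assumes the class exactly as defined in the hunk above; the URL is hypothetical):

```python
from langchain_core.prompt_values import DataPromptValue

value = DataPromptValue(
    content_block={
        "type": "image",
        "source_type": "url",
        "source": "https://example.com/cat.png",
    }
)

print(value.to_string())    # "https://example.com/cat.png"
print(value.to_messages())  # [HumanMessage(content=[{...the content block...}])]
```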
@@ -31,13 +31,16 @@ from langchain_core.messages import (
    AnyMessage,
    BaseMessage,
    ChatMessage,
    DataContentBlock,
    HumanMessage,
    SystemMessage,
    convert_to_messages,
    is_data_content_block,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.data import DataPromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
@@ -468,7 +471,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
    """Human message prompt template. This is a message sent from the user."""

    prompt: Union[
        StringPromptTemplate, list[Union[StringPromptTemplate, ImagePromptTemplate]]
        StringPromptTemplate,
        list[Union[StringPromptTemplate, ImagePromptTemplate, DataPromptTemplate]],
    ]
    """Prompt template."""
    additional_kwargs: dict = Field(default_factory=dict)
@@ -479,7 +483,10 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
    @classmethod
    def from_template(
        cls: type[Self],
        template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
        template: Union[
            str,
            list[Union[str, _TextTemplateParam, _ImageTemplateParam, DataContentBlock]],
        ],
        template_format: PromptTemplateFormat = "f-string",
        *,
        partial_variables: Optional[dict[str, Any]] = None,
@@ -562,6 +569,23 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
                        msg = f"Invalid image template: {tmpl}"
                        raise ValueError(msg)
                    prompt.append(img_template_obj)
                elif isinstance(tmpl, dict) and is_data_content_block(tmpl):  # type: ignore[arg-type]
                    data_template = cast("DataContentBlock", tmpl)
                    input_variables = []
                    for key in ["source", "source_type", "mime_type"]:
                        if key in data_template:
                            input_variables.extend(
                                get_template_variables(
                                    data_template[key],  # type: ignore[literal-required]
                                    template_format,
                                )
                            )
                    data_template_obj = DataPromptTemplate(
                        input_variables=input_variables,
                        template=data_template,
                        template_format=template_format,
                    )
                    prompt.append(data_template_obj)
                else:
                    msg = f"Invalid template: {tmpl}"
                    raise ValueError(msg)
@@ -639,11 +663,16 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
            for prompt in self.prompt:
                inputs = {var: kwargs[var] for var in prompt.input_variables}
                if isinstance(prompt, StringPromptTemplate):
                    formatted: Union[str, ImageURL] = prompt.format(**inputs)
                    formatted: Union[str, ImageURL, DataContentBlock] = prompt.format(
                        **inputs
                    )
                    content.append({"type": "text", "text": formatted})
                elif isinstance(prompt, ImagePromptTemplate):
                    formatted = prompt.format(**inputs)
                    content.append({"type": "image_url", "image_url": formatted})
                elif isinstance(prompt, DataPromptTemplate):
                    formatted = prompt.format(**inputs)
                    content.append(formatted)
            return self._msg_class(
                content=content, additional_kwargs=self.additional_kwargs
            )
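For context, a hedged sketch of how a standard data content block could flow through ChatPromptTemplate with this change (the variable name `data` is arbitrary):

```python
from langchain_core.prompts import ChatPromptTemplate

template = ChatPromptTemplate.from_messages(
    [
        (
            "human",
            [
                {"type": "text", "text": "Summarize this document:"},
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "source": "{data}",
                },
            ],
        )
    ]
)

# The data block becomes a DataPromptTemplate; its "source" field is treated as
# an f-string template, so the variable is filled in at format time.
messages = template.format_messages(data="<base64 data>")
```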
libs/core/langchain_core/prompts/data.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""Image prompt template for a multimodal model."""

from typing import Any, cast

from pydantic import Field

from langchain_core.messages import DataContentBlock
from langchain_core.prompt_values import DataPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import (
    DEFAULT_FORMATTER_MAPPING,
    PromptTemplateFormat,
)
from langchain_core.runnables import run_in_executor


class DataPromptTemplate(BasePromptTemplate[DataContentBlock]):
    """Prompt template for a multi-modal model."""

    template: dict = Field(default_factory=dict)
    """Template for the prompt."""
    template_format: PromptTemplateFormat = "f-string"
    """The format of the prompt template.
    Options are: 'f-string', 'mustache', 'jinja2'."""

    def __init__(self, **kwargs: Any) -> None:
        """Create a prompt template for multi-modal data."""
        if "input_variables" not in kwargs:
            kwargs["input_variables"] = []

        overlap = set(kwargs["input_variables"]) & {
            "source",
            "source_type",
            "mime_type",
            "metadata",
        }
        if overlap:
            msg = (
                "input_variables for the template cannot contain"
                " any of 'source', 'source_type', 'mime_type', or 'metadata'."
                f" Found: {overlap}"
            )
            raise ValueError(msg)
        super().__init__(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "data-prompt"

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "data"]

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.
        """
        return DataPromptValue(content_block=self.format(**kwargs))

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        """Async format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.
        """
        return DataPromptValue(content_block=await self.aformat(**kwargs))

    def format(
        self,
        **kwargs: Any,
    ) -> DataContentBlock:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Raises:
            ValueError: If the url is not provided.
            ValueError: If the url is not a string.

        Example:

            .. code-block:: python

                prompt.format(variable1="foo")
        """
        formatted = {}
        for k, v in self.template.items():
            if isinstance(v, str):
                formatted[k] = DEFAULT_FORMATTER_MAPPING[self.template_format](
                    v, **kwargs
                )
            else:
                formatted[k] = v

        block = {}
        for k in ["type", "source_type", "source", "mime_type", "metadata"]:
            value = kwargs.get(k) or formatted.get(k)
            if value:
                block[k] = value

        for required_field in ["source", "source_type"]:
            if required_field not in block:
                msg = f"Missing required field: {required_field}"
                raise ValueError(msg)

        return cast("DataContentBlock", block)

    async def aformat(self, **kwargs: Any) -> DataContentBlock:
        """Async format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Raises:
            ValueError: If the path or url is not a string.
        """
        return await run_in_executor(None, self.format, **kwargs)

    def pretty_repr(self, html: bool = False) -> str:
        """Return a pretty representation of the prompt.

        Args:
            html: Whether to return an html formatted string.

        Returns:
            A pretty representation of the prompt.
        """
        raise NotImplementedError
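A short usage sketch of the new DataPromptTemplate in isolation (illustrative; the `url` variable name is arbitrary):

```python
from langchain_core.prompts.data import DataPromptTemplate

template = DataPromptTemplate(
    input_variables=["url"],
    template={"type": "image", "source_type": "url", "source": "{url}"},
)

# format() applies the template format ("f-string" by default) to each string
# field and returns a DataContentBlock dict.
block = template.format(url="https://example.com/cat.png")
# {"type": "image", "source_type": "url", "source": "https://example.com/cat.png"}
```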
@@ -10,6 +10,7 @@ EXPECTED_ALL = [
    "BaseMessageChunk",
    "ChatMessage",
    "ChatMessageChunk",
    "DataContentBlock",
    "FunctionMessage",
    "FunctionMessageChunk",
    "HumanMessage",
@@ -24,6 +25,7 @@ EXPECTED_ALL = [
    "RemoveMessage",
    "convert_to_messages",
    "get_buffer_string",
    "is_data_content_block",
    "merge_content",
    "message_chunk_to_message",
    "message_to_dict",
@@ -32,6 +34,7 @@ EXPECTED_ALL = [
    "filter_messages",
    "merge_message_runs",
    "trim_messages",
    "convert_image_content_block_to_image_url",
    "convert_to_openai_messages",
]
@@ -11,7 +11,7 @@ from syrupy import SnapshotAssertion
from langchain_core._api.deprecation import (
    LangChainPendingDeprecationWarning,
)
from langchain_core.load import dumpd, load
from langchain_core.load import dump, dumpd, load, loads
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
@@ -1049,3 +1049,75 @@ def test_chat_prompt_template_variable_names() -> None:
        "title": "PromptInput",
        "type": "object",
    }


def test_data_prompt_template_deserializable() -> None:
    """Test that the image prompt template is serializable."""
    loads(
        dump.dumps(
            ChatPromptTemplate.from_messages(
                [
                    (
                        "system",
                        [{"type": "image", "source_type": "url", "source": "{url}"}],
                    )
                ]
            )
        )
    )


@pytest.mark.requires("jinja2")
@pytest.mark.parametrize(
    ("template_format", "mime_type_placeholder", "source_data_placeholder"),
    [
        ("f-string", "{media_type}", "{source_data}"),
        ("mustache", "{{media_type}}", "{{source_data}}"),
        ("jinja2", "{{ media_type }}", "{{ source_data }}"),
    ],
)
def test_chat_prompt_template_data_prompt_from_message(
    template_format: PromptTemplateFormat,
    mime_type_placeholder: str,
    source_data_placeholder: str,
) -> None:
    prompt = {
        "type": "image",
        "source_type": "base64",
        "source": f"{source_data_placeholder}",
    }

    template = ChatPromptTemplate.from_messages(
        [("human", [prompt])], template_format=template_format
    )
    assert template.format_messages(source_data="base64data") == [
        HumanMessage(
            content=[
                {
                    "type": "image",
                    "source_type": "base64",
                    "source": "base64data",
                }
            ]
        )
    ]

    # mime_type
    prompt["mime_type"] = f"{mime_type_placeholder}"
    template = ChatPromptTemplate.from_messages(
        [("human", [prompt])], template_format=template_format
    )
    assert template.format_messages(
        media_type="image/png", source_data="base64data"
    ) == [
        HumanMessage(
            content=[
                {
                    "type": "image",
                    "source_type": "base64",
                    "source": "base64data",
                    "mime_type": "image/png",
                }
            ]
        )
    ]
@@ -42,6 +42,7 @@ from langchain_core.messages import (
    SystemMessage,
    ToolCall,
    ToolMessage,
    is_data_content_block,
)
from langchain_core.messages.ai import InputTokenDetails, UsageMetadata
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
@@ -184,8 +185,75 @@ def _merge_messages(
    return merged


def _format_data_content_block(block: dict) -> dict:
    """Format standard data content block to format expected by Anthropic."""
    if block["type"] == "image":
        if block["source_type"] == "url":
            if block["source"].startswith("data:"):
                # Data URI
                formatted_block = {
                    "type": "image",
                    "source": _format_image(block["source"]),
                }
            else:
                formatted_block = {
                    "type": "image",
                    "source": {"type": "url", "url": block["source"]},
                }
        elif block["source_type"] == "base64":
            formatted_block = {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": block["mime_type"],
                    "data": block["source"],
                },
            }
        else:
            raise ValueError(
                "Anthropic only supports 'url' and 'base64' source_type for image "
                "content blocks."
            )

    elif block["type"] == "file":
        if block["source_type"] == "url":
            formatted_block = {
                "type": "document",
                "source": {
                    "type": "url",
                    "url": block["source"],
                },
            }
        elif block["source_type"] == "base64":
            formatted_block = {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": block.get("mime_type") or "application/pdf",
                    "data": block["source"],
                },
            }
        elif block["source_type"] == "text":
            formatted_block = {
                "type": "document",
                "source": {
                    "type": "text",
                    "media_type": block.get("mime_type") or "text/plain",
                    "data": block["source"],
                },
            }

    else:
        raise ValueError(f"Block of type {block['type']} is not supported.")

    if formatted_block and (metadata := block.get("metadata")):
        formatted_block = {**formatted_block, **metadata}

    return formatted_block


def _format_messages(
    messages: List[BaseMessage],
    messages: Sequence[BaseMessage],
) -> Tuple[Union[str, List[Dict], None], List[Dict]]:
    """Format messages for anthropic."""

@@ -240,6 +308,8 @@ def _format_messages(
                    # convert format
                    source = _format_image(block["image_url"]["url"])
                    content.append({"type": "image", "source": source})
                elif is_data_content_block(block):
                    content.append(_format_data_content_block(block))
                elif block["type"] == "tool_use":
                    # If a tool_call with the same id as a tool_use content block
                    # exists, the tool_call is preferred.
@@ -690,6 +690,85 @@ def test__format_messages_with_cache_control() -> None:
    assert expected_system == actual_system
    assert expected_messages == actual_messages

    # Test standard multi-modal format
    messages = [
        HumanMessage(
            [
                {
                    "type": "text",
                    "text": "Summarize this document:",
                },
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "source": "<base64 data>",
                    "metadata": {"cache_control": {"type": "ephemeral"}},
                },
            ]
        )
    ]
    actual_system, actual_messages = _format_messages(messages)
    assert actual_system is None
    expected_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Summarize this document:",
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "media_type": "application/pdf",
                        "data": "<base64 data>",
                    },
                    "cache_control": {"type": "ephemeral"},
                },
            ],
        }
    ]
    assert actual_messages == expected_messages


def test__format_messages_with_citations() -> None:
    input_messages = [
        HumanMessage(
            content=[
                {
                    "type": "file",
                    "source_type": "text",
                    "source": "The grass is green. The sky is blue.",
                    "mime_type": "text/plain",
                    "metadata": {"citations": {"enabled": True}},
                },
                {"type": "text", "text": "What color is the grass and sky?"},
            ]
        )
    ]
    expected_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "document",
                    "source": {
                        "type": "text",
                        "media_type": "text/plain",
                        "data": "The grass is green. The sky is blue.",
                    },
                    "citations": {"enabled": True},
                },
                {"type": "text", "text": "What color is the grass and sky?"},
            ],
        }
    ]
    actual_system, actual_messages = _format_messages(input_messages)
    assert actual_system is None
    assert actual_messages == expected_messages


def test__format_messages_with_multiple_system() -> None:
    messages = [
@@ -68,6 +68,8 @@ from langchain_core.messages import (
    ToolCall,
    ToolMessage,
    ToolMessageChunk,
    convert_image_content_block_to_image_url,
    is_data_content_block,
)
from langchain_core.messages.ai import (
    InputTokenDetails,
@@ -191,6 +193,29 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)  # type: ignore[arg-type]


def _format_data_content_block(block: dict) -> dict:
    """Format standard data content block to format expected by OpenAI."""
    if block["type"] == "image":
        formatted_block = convert_image_content_block_to_image_url(block)  # type: ignore[arg-type]

    elif block["type"] == "file":
        if block["source_type"] == "base64":
            file = {"file_data": f"data:{block['mime_type']};base64,{block['source']}"}
            if metadata := block.get("metadata"):
                file = {**file, **metadata}
            # Hack to support cross-compatibility with providers that do not require
            # filename (OpenAI requires one).
            if "filename" not in file:
                file["filename"] = ""
            formatted_block = {"type": "file", "file": file}
        elif block["source_type"] == "id":
            formatted_block = {"type": "file", "file": {"file_id": block["source"]}}
    else:
        raise ValueError(f"Block of type {block['type']} is not supported.")

    return formatted_block


def _format_message_content(content: Any) -> Any:
    """Format message content."""
    if content and isinstance(content, list):
@@ -203,6 +228,8 @@ def _format_message_content(content: Any) -> Any:
                and block["type"] in ("tool_use", "thinking")
            ):
                continue
            elif is_data_content_block(block):
                formatted_content.append(_format_data_content_block(block))
            # Anthropic image blocks
            elif (
                isinstance(block, dict)
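Illustrative only: the mapping this function implies for a standard base64 file block with a filename in its metadata (the filename is hypothetical):

```python
block = {
    "type": "file",
    "source_type": "base64",
    "mime_type": "application/pdf",
    "source": "<base64 data>",
    "metadata": {"filename": "report.pdf"},
}

# _format_data_content_block (above) builds the OpenAI Chat Completions "file" shape:
expected = {
    "type": "file",
    "file": {
        "file_data": "data:application/pdf;base64,<base64 data>",
        "filename": "report.pdf",
    },
}
```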
@@ -1891,6 +1891,26 @@ class ChatModelIntegrationTests(ChatModelTests):
        result = model_with_tools.invoke(messages)
        assert isinstance(result, AIMessage)

    def test_pdf_inputs(self, model: BaseChatModel) -> None:
        url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
        pdf_data = base64.b64encode(httpx.get(url).content).decode("utf-8")

        message = HumanMessage(
            [
                {
                    "type": "text",
                    "text": "Summarize this document:",
                },
                {
                    "type": "file",
                    "source_type": "base64",
                    "mime_type": "application/pdf",
                    "source": pdf_data,
                },
            ]
        )
        _ = model.invoke([message])

    def test_image_inputs(self, model: BaseChatModel) -> None:
        """Test that the model can process image inputs.

@@ -1932,6 +1952,8 @@ class ChatModelIntegrationTests(ChatModelTests):
            pytest.skip("Model does not support image message.")
        image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
        image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")

        # OpenAI format, base64 data
        message = HumanMessage(
            content=[
                {"type": "text", "text": "describe the weather in this image"},
@@ -1943,6 +1965,33 @@ class ChatModelIntegrationTests(ChatModelTests):
        )
        model.invoke([message])

        # Standard format, base64 data
        message = HumanMessage(
            content=[
                {"type": "text", "text": "describe the weather in this image"},
                {
                    "type": "image",
                    "source_type": "base64",
                    "mime_type": "image/jpeg",
                    "source": image_data,
                },
            ],
        )
        _ = model.invoke([message])

        # Standard format, URL  # TODO: gate this behind a property
        message = HumanMessage(
            content=[
                {"type": "text", "text": "describe the weather in this image"},
                {
                    "type": "image",
                    "source_type": "url",
                    "source": image_url,
                },
            ],
        )
        _ = model.invoke([message])

    def test_image_tool_message(self, model: BaseChatModel) -> None:
        """Test that the model can process ToolMessages with image inputs.