trace images in OAI chat completions format

Chester Curme 2025-04-09 09:46:06 -04:00
parent 4c23ceb9ef
commit 0354dec091
4 changed files with 71 additions and 4 deletions

View File

@@ -52,7 +52,9 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    convert_image_content_block_to_image_url,
     convert_to_messages,
+    is_data_content_block,
     message_chunk_to_message,
 )
 from langchain_core.outputs import (
@@ -103,6 +105,36 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
     return generations
 
 
+def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
+    """Format messages for tracing in on_chat_model_start.
+
+    For backward compatibility, we update image content blocks to OpenAI Chat
+    Completions format.
+
+    Args:
+        messages: List of messages to format.
+
+    Returns:
+        List of messages formatted for tracing.
+    """
+    messages_to_trace = []
+    for message in messages:
+        message_to_trace = message
+        if isinstance(message.content, list):
+            for idx, block in enumerate(message.content):
+                if (
+                    isinstance(block, dict)
+                    and is_data_content_block(block)
+                    and block.get("type") == "image"
+                ):
+                    message_to_trace = message.model_copy(deep=True)
+                    message_to_trace.content[idx] = (  # type: ignore[index]  # mypy confused by .model_copy
+                        convert_image_content_block_to_image_url(block)  # type: ignore[arg-type]
+                    )
+        messages_to_trace.append(message_to_trace)
+    return messages_to_trace
+
+
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
     """Generate from a stream.
@@ -439,7 +471,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         (run_manager,) = callback_manager.on_chat_model_start(
             self._serialized,
-            [messages],
+            [_format_for_tracing(messages)],
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
@@ -524,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
             self._serialized,
-            [messages],
+            [_format_for_tracing(messages)],
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
@@ -703,9 +735,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             inheritable_metadata,
             self.metadata,
         )
+        messages_to_trace = [
+            _format_for_tracing(message_list) for message_list in messages
+        ]
         run_managers = callback_manager.on_chat_model_start(
             self._serialized,
-            messages,
+            messages_to_trace,
             invocation_params=params,
             options=options,
             name=run_name,
@@ -812,9 +847,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
+        messages_to_trace = [
+            _format_for_tracing(message_list) for message_list in messages
+        ]
         run_managers = await callback_manager.on_chat_model_start(
             self._serialized,
-            messages,
+            messages_to_trace,
             invocation_params=params,
             options=options,
             name=run_name,
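
For illustration only (not part of the commit), a minimal sketch of what the new helper produces for a message carrying an image data content block. It assumes _format_for_tracing lives alongside BaseChatModel in langchain_core.language_models.chat_models and that the dict below satisfies DataContentBlock.__required_keys__; its keys mirror the ones the converter reads (source_type, source):

from langchain_core.language_models.chat_models import _format_for_tracing
from langchain_core.messages import HumanMessage

# A hypothetical multimodal message: one text block and one image block
# in the provider-agnostic data content block shape.
message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {"type": "image", "source_type": "url", "source": "https://example.com/cat.png"},
    ]
)

[traced] = _format_for_tracing([message])
# Assuming the image dict is recognized by is_data_content_block, the traced
# copy now holds the OpenAI Chat Completions shape at index 1:
#   {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
# The original message is untouched: the helper deep-copies before rewriting,
# so only the payload sent to on_chat_model_start is converted.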

View File

@@ -29,6 +29,7 @@ from langchain_core.messages.base import (
 from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
 from langchain_core.messages.content_blocks import (
     DataContentBlock,
+    convert_image_content_block_to_image_url,
     is_data_content_block,
 )
 from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
@@ -79,6 +80,7 @@ __all__ = [
     "ToolMessageChunk",
     "RemoveMessage",
     "_message_from_dict",
+    "convert_image_content_block_to_image_url",
     "convert_to_messages",
     "get_buffer_string",
     "is_data_content_block",

View File

@@ -33,3 +33,27 @@ def is_data_content_block(
     """
     required_keys = DataContentBlock.__required_keys__
     return all(required_key in content_block for required_key in required_keys)
+
+
+def convert_image_content_block_to_image_url(content_block: DataContentBlock) -> dict:
+    """Convert image content block to format expected by OpenAI Chat Completions API."""
+    if content_block["source_type"] == "url":
+        return {
+            "type": "image_url",
+            "image_url": {
+                "url": content_block["source"],
+            },
+        }
+    if content_block["source_type"] == "base64":
+        if "mime_type" not in content_block:
+            error_message = "mime_type key is required for base64 data."
+            raise ValueError(error_message)
+        mime_type = content_block["mime_type"]
+        return {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:{mime_type};base64,{content_block['source']}",
+            },
+        }
+    error_message = "Unsupported source type. Only 'url' and 'base64' are supported."
+    raise ValueError(error_message)
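
A hedged usage sketch of the new converter (not part of the diff), showing the two supported source types; the dict literals stand in for DataContentBlock values and the base64 payload is a placeholder:

from langchain_core.messages import convert_image_content_block_to_image_url

# URL-backed image block: passed through as an OpenAI image_url part.
convert_image_content_block_to_image_url(
    {"type": "image", "source_type": "url", "source": "https://example.com/cat.png"}
)
# -> {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}

# Base64-backed block: wrapped in a data URL; mime_type is required here.
convert_image_content_block_to_image_url(
    {
        "type": "image",
        "source_type": "base64",
        "mime_type": "image/png",
        "source": "iVBORw0KGgo...",  # placeholder base64 payload
    }
)
# -> {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}}

# Any other source_type, or a base64 block without mime_type, raises ValueError.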

View File

@@ -10,6 +10,7 @@ EXPECTED_ALL = [
     "BaseMessageChunk",
     "ChatMessage",
     "ChatMessageChunk",
+    "DataContentBlock",
     "FunctionMessage",
     "FunctionMessageChunk",
     "HumanMessage",
@@ -24,6 +25,7 @@ EXPECTED_ALL = [
     "RemoveMessage",
     "convert_to_messages",
     "get_buffer_string",
+    "is_data_content_block",
     "merge_content",
     "message_chunk_to_message",
     "message_to_dict",
@@ -32,6 +34,7 @@ EXPECTED_ALL = [
     "filter_messages",
     "merge_message_runs",
     "trim_messages",
+    "convert_image_content_block_to_image_url",
     "convert_to_openai_messages",
 ]