trace images in OAI chat completions format

This commit is contained in:
Chester Curme 2025-04-09 09:46:06 -04:00
parent 4c23ceb9ef
commit 0354dec091
4 changed files with 71 additions and 4 deletions

View File

@ -52,7 +52,9 @@ from langchain_core.messages import (
BaseMessage,
BaseMessageChunk,
HumanMessage,
convert_image_content_block_to_image_url,
convert_to_messages,
is_data_content_block,
message_chunk_to_message,
)
from langchain_core.outputs import (
@ -103,6 +105,36 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
return generations
def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Format messages for tracing in on_chat_model_start.

    For backward compatibility, we update image content blocks to OpenAI Chat
    Completions format.

    The input messages are never mutated. A message whose content is a list
    containing at least one image data content block is copied once, and every
    such block in the copy is converted. All other messages are returned as the
    same objects that were passed in.

    Args:
        messages: List of messages to format.

    Returns:
        List of messages formatted for tracing, in the same order as the input.
    """
    messages_to_trace = []
    for message in messages:
        message_to_trace = message
        if isinstance(message.content, list):
            for idx, block in enumerate(message.content):
                if (
                    isinstance(block, dict)
                    and is_data_content_block(block)
                    and block.get("type") == "image"
                ):
                    if message_to_trace is message:
                        # Copy lazily and only once per message. Re-copying the
                        # original for every image block would discard the
                        # conversions already applied to earlier blocks, so only
                        # the last image in a message would end up converted.
                        message_to_trace = message.model_copy(deep=True)
                    # Indices line up because the copy has the same content
                    # layout as the message being enumerated.
                    message_to_trace.content[idx] = (  # type: ignore[index]  # mypy confused by .model_copy
                        convert_image_content_block_to_image_url(block)  # type: ignore[arg-type]
                    )
        messages_to_trace.append(message_to_trace)
    return messages_to_trace
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream.
@ -439,7 +471,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
(run_manager,) = callback_manager.on_chat_model_start(
self._serialized,
[messages],
[_format_for_tracing(messages)],
invocation_params=params,
options=options,
name=config.get("run_name"),
@ -524,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
(run_manager,) = await callback_manager.on_chat_model_start(
self._serialized,
[messages],
[_format_for_tracing(messages)],
invocation_params=params,
options=options,
name=config.get("run_name"),
@ -703,9 +735,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
inheritable_metadata,
self.metadata,
)
messages_to_trace = [
_format_for_tracing(message_list) for message_list in messages
]
run_managers = callback_manager.on_chat_model_start(
self._serialized,
messages,
messages_to_trace,
invocation_params=params,
options=options,
name=run_name,
@ -812,9 +847,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
self.metadata,
)
messages_to_trace = [
_format_for_tracing(message_list) for message_list in messages
]
run_managers = await callback_manager.on_chat_model_start(
self._serialized,
messages,
messages_to_trace,
invocation_params=params,
options=options,
name=run_name,

View File

@ -29,6 +29,7 @@ from langchain_core.messages.base import (
from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
from langchain_core.messages.content_blocks import (
DataContentBlock,
convert_image_content_block_to_image_url,
is_data_content_block,
)
from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk
@ -79,6 +80,7 @@ __all__ = [
"ToolMessageChunk",
"RemoveMessage",
"_message_from_dict",
"convert_image_content_block_to_image_url",
"convert_to_messages",
"get_buffer_string",
"is_data_content_block",

View File

@ -33,3 +33,27 @@ def is_data_content_block(
"""
required_keys = DataContentBlock.__required_keys__
return all(required_key in content_block for required_key in required_keys)
def convert_image_content_block_to_image_url(content_block: DataContentBlock) -> dict:
    """Convert image content block to format expected by OpenAI Chat Completions API.

    Args:
        content_block: Data content block describing an image. Its
            ``source_type`` must be ``"url"`` or ``"base64"``; base64 blocks
            must also carry a ``mime_type``.

    Returns:
        A dict of the form ``{"type": "image_url", "image_url": {"url": ...}}``.
        URL sources are passed through unchanged; base64 sources are wrapped in
        a ``data:<mime_type>;base64,<source>`` URL.

    Raises:
        ValueError: If a base64 block has no ``mime_type`` key, or if
            ``source_type`` is neither ``"url"`` nor ``"base64"``.
    """
    source_type = content_block["source_type"]

    if source_type == "url":
        image_url = content_block["source"]
    elif source_type == "base64":
        # The MIME type is validated before the payload is read.
        if "mime_type" not in content_block:
            error_message = "mime_type key is required for base64 data."
            raise ValueError(error_message)
        image_url = (
            f"data:{content_block['mime_type']};base64,{content_block['source']}"
        )
    else:
        error_message = (
            "Unsupported source type. Only 'url' and 'base64' are supported."
        )
        raise ValueError(error_message)

    return {"type": "image_url", "image_url": {"url": image_url}}

View File

@ -10,6 +10,7 @@ EXPECTED_ALL = [
"BaseMessageChunk",
"ChatMessage",
"ChatMessageChunk",
"DataContentBlock",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
@ -24,6 +25,7 @@ EXPECTED_ALL = [
"RemoveMessage",
"convert_to_messages",
"get_buffer_string",
"is_data_content_block",
"merge_content",
"message_chunk_to_message",
"message_to_dict",
@ -32,6 +34,7 @@ EXPECTED_ALL = [
"filter_messages",
"merge_message_runs",
"trim_messages",
"convert_image_content_block_to_image_url",
"convert_to_openai_messages",
]