mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-01 00:49:25 +00:00

trace images in OAI chat completions format

This commit is contained in:
parent 4c23ceb9ef
commit 0354dec091

@@ -52,7 +52,9 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    convert_image_content_block_to_image_url,
     convert_to_messages,
+    is_data_content_block,
     message_chunk_to_message,
 )
 from langchain_core.outputs import (

@@ -103,6 +105,36 @@ def _generate_response_from_error(error: BaseException) -> list[ChatGeneration]:
     return generations


+def _format_for_tracing(messages: list[BaseMessage]) -> list[BaseMessage]:
+    """Format messages for tracing in on_chat_model_start.
+
+    For backward compatibility, we update image content blocks to OpenAI Chat
+    Completions format.
+
+    Args:
+        messages: List of messages to format.
+
+    Returns:
+        List of messages formatted for tracing.
+    """
+    messages_to_trace = []
+    for message in messages:
+        message_to_trace = message
+        if isinstance(message.content, list):
+            for idx, block in enumerate(message.content):
+                if (
+                    isinstance(block, dict)
+                    and is_data_content_block(block)
+                    and block.get("type") == "image"
+                ):
+                    message_to_trace = message.model_copy(deep=True)
+                    message_to_trace.content[idx] = (  # type: ignore[index]  # mypy confused by .model_copy
+                        convert_image_content_block_to_image_url(block)  # type: ignore[arg-type]
+                    )
+        messages_to_trace.append(message_to_trace)
+    return messages_to_trace
+
+
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
     """Generate from a stream.

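A quick usage sketch of the new helper (not part of the diff). The block keys match the ones the helper itself reads (type, source_type, source, mime_type); the import path for the private helper and the base64 payload are assumptions for illustration.

# Sketch only: exercises the _format_for_tracing helper added above.
from langchain_core.language_models.chat_models import _format_for_tracing  # assumed module path
from langchain_core.messages import HumanMessage

messages = [
    HumanMessage(
        content=[
            {"type": "text", "text": "What is in this picture?"},
            {
                "type": "image",
                "source_type": "base64",
                "mime_type": "image/png",
                "source": "iVBORw0KGgo...",  # illustrative, truncated base64 payload
            },
        ]
    )
]

traced = _format_for_tracing(messages)
# traced[0].content[1] is now in OpenAI Chat Completions form:
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}}
# The original message is left untouched: the helper deep-copies a message only
# when it actually rewrites one of its image blocks, so messages without image
# data blocks are passed through as the same objects.
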
@@ -439,7 +471,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         (run_manager,) = callback_manager.on_chat_model_start(
             self._serialized,
-            [messages],
+            [_format_for_tracing(messages)],
             invocation_params=params,
             options=options,
             name=config.get("run_name"),

@@ -524,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
             self._serialized,
-            [messages],
+            [_format_for_tracing(messages)],
             invocation_params=params,
             options=options,
             name=config.get("run_name"),

@@ -703,9 +735,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             inheritable_metadata,
             self.metadata,
         )
+        messages_to_trace = [
+            _format_for_tracing(message_list) for message_list in messages
+        ]
         run_managers = callback_manager.on_chat_model_start(
             self._serialized,
-            messages,
+            messages_to_trace,
             invocation_params=params,
             options=options,
             name=run_name,

@@ -812,9 +847,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )

+        messages_to_trace = [
+            _format_for_tracing(message_list) for message_list in messages
+        ]
         run_managers = await callback_manager.on_chat_model_start(
             self._serialized,
-            messages,
+            messages_to_trace,
             invocation_params=params,
             options=options,
             name=run_name,

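All four call sites above keep passing the original messages to the model and hand only the converted copy to on_chat_model_start. Below is a minimal sketch of what a tracer-style callback observes after this change; the handler class is hypothetical, but the on_chat_model_start hook with (serialized, messages) positional arguments is the standard BaseCallbackHandler interface.

from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class RecordingHandler(BaseCallbackHandler):
    """Hypothetical handler that records the message payloads it is given."""

    def __init__(self) -> None:
        self.seen: list[list[BaseMessage]] = []

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        # With the change above, any image data content blocks in `messages`
        # arrive already rewritten as {"type": "image_url", ...} blocks.
        self.seen.extend(messages)

Passed via config={"callbacks": [RecordingHandler()]} on an invoke call, such a handler would see the OpenAI-style image blocks while the model itself still receives the original content blocks.
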
@@ -29,6 +29,7 @@ from langchain_core.messages.base import (
 from langchain_core.messages.chat import ChatMessage, ChatMessageChunk
 from langchain_core.messages.content_blocks import (
     DataContentBlock,
+    convert_image_content_block_to_image_url,
     is_data_content_block,
 )
 from langchain_core.messages.function import FunctionMessage, FunctionMessageChunk

@@ -79,6 +80,7 @@ __all__ = [
     "ToolMessageChunk",
     "RemoveMessage",
     "_message_from_dict",
+    "convert_image_content_block_to_image_url",
     "convert_to_messages",
     "get_buffer_string",
     "is_data_content_block",

@@ -33,3 +33,27 @@ def is_data_content_block(
     """
     required_keys = DataContentBlock.__required_keys__
     return all(required_key in content_block for required_key in required_keys)
+
+
+def convert_image_content_block_to_image_url(content_block: DataContentBlock) -> dict:
+    """Convert image content block to format expected by OpenAI Chat Completions API."""
+    if content_block["source_type"] == "url":
+        return {
+            "type": "image_url",
+            "image_url": {
+                "url": content_block["source"],
+            },
+        }
+    if content_block["source_type"] == "base64":
+        if "mime_type" not in content_block:
+            error_message = "mime_type key is required for base64 data."
+            raise ValueError(error_message)
+        mime_type = content_block["mime_type"]
+        return {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:{mime_type};base64,{content_block['source']}",
+            },
+        }
+    error_message = "Unsupported source type. Only 'url' and 'base64' are supported."
+    raise ValueError(error_message)

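A short sketch of the new converter in isolation, using the public import that the __init__ hunk above now exposes; the sample URL and base64 payload are illustrative.

from langchain_core.messages import convert_image_content_block_to_image_url

# URL-sourced block: the URL is passed through unchanged.
convert_image_content_block_to_image_url(
    {"type": "image", "source_type": "url", "source": "https://example.com/cat.png"}
)
# -> {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}

# Base64-sourced block: a data URL is built, so mime_type is required.
convert_image_content_block_to_image_url(
    {
        "type": "image",
        "source_type": "base64",
        "mime_type": "image/jpeg",
        "source": "/9j/4AAQ...",  # illustrative, truncated payload
    }
)
# -> {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."}}
# Any other source_type raises ValueError.
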
@@ -10,6 +10,7 @@ EXPECTED_ALL = [
     "BaseMessageChunk",
     "ChatMessage",
     "ChatMessageChunk",
+    "DataContentBlock",
     "FunctionMessage",
     "FunctionMessageChunk",
     "HumanMessage",

@@ -24,6 +25,7 @@ EXPECTED_ALL = [
     "RemoveMessage",
     "convert_to_messages",
     "get_buffer_string",
+    "is_data_content_block",
     "merge_content",
     "message_chunk_to_message",
     "message_to_dict",

@@ -32,6 +34,7 @@ EXPECTED_ALL = [
     "filter_messages",
     "merge_message_runs",
     "trim_messages",
+    "convert_image_content_block_to_image_url",
     "convert_to_openai_messages",
 ]