core[patch]: fix edge cases for _is_openai_data_block (#30997)

Author: ccurme
Date: 2025-04-24 10:48:52 -04:00 (committed by GitHub)
Commit: f4863f82e2
Parent: ae4b6380d9
2 changed files with 52 additions and 16 deletions
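
The edge cases fixed here are content blocks whose "type" matches an OpenAI Chat Completions key but whose payload is a plain string rather than the expected dict, for example {"type": "file", "file": "uri"} or {"type": "input_audio", "input_audio": "uri"} (see the new test below). The old helper chained .get() calls on the assumption that the payload was a dict, so a string payload raised AttributeError; the patched helper checks the payload type first and lets such blocks fall through to return False, so _normalize_messages passes them through unchanged. A minimal repro sketch (standalone and hypothetical; only the block shapes and the two .get() patterns come from the diff below):

# Hypothetical standalone repro; not part of the diff.
block = {"type": "file", "file": "uri"}  # payload is a str, not the dict the old code expected

# Old check: chains .get() as if block["file"] were a dict.
try:
    block.get("file", {}).get("file_data")
except AttributeError as err:
    print(f"old check raises: {err}")  # 'str' object has no attribute 'get'

# Patched check: confirm the payload is a dict before reading it; otherwise the block
# is not treated as OpenAI Chat Completions data and is left untouched downstream.
file = block.get("file")
is_openai_data = bool(file and isinstance(file, dict) and isinstance(file.get("file_data"), str))
print(is_openai_data)  # False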

File 1 of 2: langchain_core/language_models/_utils.py

@@ -1,4 +1,5 @@
 import re
+from collections.abc import Sequence
 from typing import Optional
 
 from langchain_core.messages import BaseMessage
@@ -7,24 +8,30 @@ from langchain_core.messages import BaseMessage
 def _is_openai_data_block(block: dict) -> bool:
     """Check if the block contains multimodal data in OpenAI Chat Completions format."""
     if block.get("type") == "image_url":
-        url = block.get("image_url", {}).get("url")
-        if isinstance(url, str) and set(block.keys()) <= {
-            "type",
-            "image_url",
-            "detail",
-        }:
-            return True
+        if (
+            (set(block.keys()) <= {"type", "image_url", "detail"})
+            and (image_url := block.get("image_url"))
+            and isinstance(image_url, dict)
+        ):
+            url = image_url.get("url")
+            if isinstance(url, str):
+                return True
 
     elif block.get("type") == "file":
-        data = block.get("file", {}).get("file_data")
-        if isinstance(data, str):
-            return True
+        if (file := block.get("file")) and isinstance(file, dict):
+            file_data = file.get("file_data")
+            if isinstance(file_data, str):
+                return True
 
-    elif block.get("type") == "input_audio":
-        audio_data = block.get("input_audio", {}).get("data")
-        audio_format = block.get("input_audio", {}).get("format")
-        if isinstance(audio_data, str) and isinstance(audio_format, str):
-            return True
+    elif block.get("type") == "input_audio":  # noqa: SIM102
+        if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
+            audio_data = input_audio.get("data")
+            audio_format = input_audio.get("format")
+            if isinstance(audio_data, str) and isinstance(audio_format, str):
+                return True
+
+    else:
+        return False
 
     return False
@@ -98,7 +105,7 @@ def _convert_openai_format_to_data_block(block: dict) -> dict:
     return block
 
 
-def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
+def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
     """Extend support for message formats.
 
     Chat models implement support for images in OpenAI Chat Completions format, as well
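
The hunk above also widens _normalize_messages to accept any Sequence[BaseMessage] rather than only a list, while still returning a list. A rough usage sketch (assuming langchain_core is installed; _normalize_messages is a private helper, so the import below mirrors the test import rather than a public API):

from langchain_core.language_models._utils import _normalize_messages
from langchain_core.messages import HumanMessage

# A tuple (or any other Sequence) of messages now satisfies the annotation.
messages = (HumanMessage("hello"),)
normalized = _normalize_messages(messages)
assert isinstance(normalized, list)  # the helper returns a list either way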

File 2 of 2: chat model unit tests

@@ -13,6 +13,7 @@ from langchain_core.language_models import (
     FakeListChatModel,
     ParrotFakeChatModel,
 )
+from langchain_core.language_models._utils import _normalize_messages
 from langchain_core.language_models.fake_chat_models import FakeListChatModelError
 from langchain_core.messages import (
     AIMessage,
@@ -567,3 +568,31 @@ def test_extend_support_to_openai_multimodal_formats() -> None:
             "input_audio": {"data": "<base64 data>", "format": "wav"},
         },
     ]
+
+
+def test_normalize_messages_edge_cases() -> None:
+    # Test some blocks that should pass through
+    messages = [
+        HumanMessage(
+            content=[
+                {
+                    "type": "file",
+                    "file": "uri",
+                },
+                {
+                    "type": "input_file",
+                    "file_data": "uri",
+                    "filename": "file-name",
+                },
+                {
+                    "type": "input_audio",
+                    "input_audio": "uri",
+                },
+                {
+                    "type": "input_image",
+                    "image_url": "uri",
+                },
+            ]
+        )
+    ]
+    assert messages == _normalize_messages(messages)