1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-05 07:08:03 +00:00

core[patch]: fix edge cases for `_is_openai_data_block`

This commit is contained in:
ccurme 2025-04-24 10:48:52 -04:00 committed by GitHub
parent ae4b6380d9
commit f4863f82e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 16 deletions
libs/core
langchain_core/language_models
tests/unit_tests/language_models/chat_models

View File

@ -1,4 +1,5 @@
import re
from collections.abc import Sequence
from typing import Optional
from langchain_core.messages import BaseMessage
@ -7,24 +8,30 @@ from langchain_core.messages import BaseMessage
def _is_openai_data_block(block: dict) -> bool:
"""Check if the block contains multimodal data in OpenAI Chat Completions format."""
if block.get("type") == "image_url":
url = block.get("image_url", {}).get("url")
if isinstance(url, str) and set(block.keys()) <= {
"type",
"image_url",
"detail",
}:
return True
if (
(set(block.keys()) <= {"type", "image_url", "detail"})
and (image_url := block.get("image_url"))
and isinstance(image_url, dict)
):
url = image_url.get("url")
if isinstance(url, str):
return True
elif block.get("type") == "file":
data = block.get("file", {}).get("file_data")
if isinstance(data, str):
return True
if (file := block.get("file")) and isinstance(file, dict):
file_data = file.get("file_data")
if isinstance(file_data, str):
return True
elif block.get("type") == "input_audio":
audio_data = block.get("input_audio", {}).get("data")
audio_format = block.get("input_audio", {}).get("format")
if isinstance(audio_data, str) and isinstance(audio_format, str):
return True
elif block.get("type") == "input_audio": # noqa: SIM102
if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
audio_data = input_audio.get("data")
audio_format = input_audio.get("format")
if isinstance(audio_data, str) and isinstance(audio_format, str):
return True
else:
return False
return False
@ -98,7 +105,7 @@ def _convert_openai_format_to_data_block(block: dict) -> dict:
return block
def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
"""Extend support for message formats.
Chat models implement support for images in OpenAI Chat Completions format, as well

View File

@ -13,6 +13,7 @@ from langchain_core.language_models import (
FakeListChatModel,
ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
from langchain_core.messages import (
AIMessage,
@ -567,3 +568,31 @@ def test_extend_support_to_openai_multimodal_formats() -> None:
"input_audio": {"data": "<base64 data>", "format": "wav"},
},
]
def test_normalize_messages_edge_cases() -> None:
    """Blocks that are not OpenAI-format multimodal data must pass through unchanged."""
    # Each of these blocks is deliberately NOT in OpenAI Chat Completions
    # multimodal format (payloads are bare strings or use Responses-API keys),
    # so normalization must leave them untouched.
    pass_through_blocks: list[dict] = [
        {"type": "file", "file": "uri"},
        {"type": "input_file", "file_data": "uri", "filename": "file-name"},
        {"type": "input_audio", "input_audio": "uri"},
        {"type": "input_image", "image_url": "uri"},
    ]
    messages = [HumanMessage(content=pass_through_blocks)]
    assert messages == _normalize_messages(messages)