mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-05 07:08:03 +00:00
core[patch]: fix edge cases for _is_openai_data_block (#30997)
This commit is contained in:
parent
ae4b6380d9
commit
f4863f82e2
libs/core
@ -1,4 +1,5 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
@ -7,24 +8,30 @@ from langchain_core.messages import BaseMessage
|
||||
def _is_openai_data_block(block: dict) -> bool:
|
||||
"""Check if the block contains multimodal data in OpenAI Chat Completions format."""
|
||||
if block.get("type") == "image_url":
|
||||
url = block.get("image_url", {}).get("url")
|
||||
if isinstance(url, str) and set(block.keys()) <= {
|
||||
"type",
|
||||
"image_url",
|
||||
"detail",
|
||||
}:
|
||||
return True
|
||||
if (
|
||||
(set(block.keys()) <= {"type", "image_url", "detail"})
|
||||
and (image_url := block.get("image_url"))
|
||||
and isinstance(image_url, dict)
|
||||
):
|
||||
url = image_url.get("url")
|
||||
if isinstance(url, str):
|
||||
return True
|
||||
|
||||
elif block.get("type") == "file":
|
||||
data = block.get("file", {}).get("file_data")
|
||||
if isinstance(data, str):
|
||||
return True
|
||||
if (file := block.get("file")) and isinstance(file, dict):
|
||||
file_data = file.get("file_data")
|
||||
if isinstance(file_data, str):
|
||||
return True
|
||||
|
||||
elif block.get("type") == "input_audio":
|
||||
audio_data = block.get("input_audio", {}).get("data")
|
||||
audio_format = block.get("input_audio", {}).get("format")
|
||||
if isinstance(audio_data, str) and isinstance(audio_format, str):
|
||||
return True
|
||||
elif block.get("type") == "input_audio": # noqa: SIM102
|
||||
if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
|
||||
audio_data = input_audio.get("data")
|
||||
audio_format = input_audio.get("format")
|
||||
if isinstance(audio_data, str) and isinstance(audio_format, str):
|
||||
return True
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
@ -98,7 +105,7 @@ def _convert_openai_format_to_data_block(block: dict) -> dict:
|
||||
return block
|
||||
|
||||
|
||||
def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Extend support for message formats.
|
||||
|
||||
Chat models implement support for images in OpenAI Chat Completions format, as well
|
||||
|
@ -13,6 +13,7 @@ from langchain_core.language_models import (
|
||||
FakeListChatModel,
|
||||
ParrotFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@ -567,3 +568,31 @@ def test_extend_support_to_openai_multimodal_formats() -> None:
|
||||
"input_audio": {"data": "<base64 data>", "format": "wav"},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_normalize_messages_edge_cases() -> None:
    """Check that the listed content blocks come back from normalization unchanged."""
    # Blocks that should pass through: the payload under each data key is a
    # plain string rather than a nested dict.
    passthrough_blocks = [
        {"type": "file", "file": "uri"},
        {"type": "input_file", "file_data": "uri", "filename": "file-name"},
        {"type": "input_audio", "input_audio": "uri"},
        {"type": "input_image", "image_url": "uri"},
    ]
    original = [HumanMessage(content=passthrough_blocks)]

    assert original == _normalize_messages(original)
|
||||
|
Loading…
Reference in New Issue
Block a user