mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core[patch]: fix edge cases for _is_openai_data_block (#30997)
This commit is contained in:
parent
ae4b6380d9
commit
f4863f82e2
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
@@ -7,24 +8,30 @@ from langchain_core.messages import BaseMessage
|
|||||||
def _is_openai_data_block(block: dict) -> bool:
|
def _is_openai_data_block(block: dict) -> bool:
|
||||||
"""Check if the block contains multimodal data in OpenAI Chat Completions format."""
|
"""Check if the block contains multimodal data in OpenAI Chat Completions format."""
|
||||||
if block.get("type") == "image_url":
|
if block.get("type") == "image_url":
|
||||||
url = block.get("image_url", {}).get("url")
|
if (
|
||||||
if isinstance(url, str) and set(block.keys()) <= {
|
(set(block.keys()) <= {"type", "image_url", "detail"})
|
||||||
"type",
|
and (image_url := block.get("image_url"))
|
||||||
"image_url",
|
and isinstance(image_url, dict)
|
||||||
"detail",
|
):
|
||||||
}:
|
url = image_url.get("url")
|
||||||
return True
|
if isinstance(url, str):
|
||||||
|
return True
|
||||||
|
|
||||||
elif block.get("type") == "file":
|
elif block.get("type") == "file":
|
||||||
data = block.get("file", {}).get("file_data")
|
if (file := block.get("file")) and isinstance(file, dict):
|
||||||
if isinstance(data, str):
|
file_data = file.get("file_data")
|
||||||
return True
|
if isinstance(file_data, str):
|
||||||
|
return True
|
||||||
|
|
||||||
elif block.get("type") == "input_audio":
|
elif block.get("type") == "input_audio": # noqa: SIM102
|
||||||
audio_data = block.get("input_audio", {}).get("data")
|
if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
|
||||||
audio_format = block.get("input_audio", {}).get("format")
|
audio_data = input_audio.get("data")
|
||||||
if isinstance(audio_data, str) and isinstance(audio_format, str):
|
audio_format = input_audio.get("format")
|
||||||
return True
|
if isinstance(audio_data, str) and isinstance(audio_format, str):
|
||||||
|
return True
|
||||||
|
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -98,7 +105,7 @@ def _convert_openai_format_to_data_block(block: dict) -> dict:
|
|||||||
return block
|
return block
|
||||||
|
|
||||||
|
|
||||||
def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
|
||||||
"""Extend support for message formats.
|
"""Extend support for message formats.
|
||||||
|
|
||||||
Chat models implement support for images in OpenAI Chat Completions format, as well
|
Chat models implement support for images in OpenAI Chat Completions format, as well
|
||||||
|
@@ -13,6 +13,7 @@ from langchain_core.language_models import (
|
|||||||
FakeListChatModel,
|
FakeListChatModel,
|
||||||
ParrotFakeChatModel,
|
ParrotFakeChatModel,
|
||||||
)
|
)
|
||||||
|
from langchain_core.language_models._utils import _normalize_messages
|
||||||
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
from langchain_core.language_models.fake_chat_models import FakeListChatModelError
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@@ -567,3 +568,31 @@ def test_extend_support_to_openai_multimodal_formats() -> None:
|
|||||||
"input_audio": {"data": "<base64 data>", "format": "wav"},
|
"input_audio": {"data": "<base64 data>", "format": "wav"},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_messages_edge_cases() -> None:
|
||||||
|
# Test some blocks that should pass through
|
||||||
|
messages = [
|
||||||
|
HumanMessage(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"file": "uri",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "input_file",
|
||||||
|
"file_data": "uri",
|
||||||
|
"filename": "file-name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "input_audio",
|
||||||
|
"input_audio": "uri",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "input_image",
|
||||||
|
"image_url": "uri",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
assert messages == _normalize_messages(messages)
|
||||||
|
Loading…
Reference in New Issue
Block a user