diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py
index bc2285f9201..f54d69728dd 100644
--- a/libs/core/langchain_core/language_models/_utils.py
+++ b/libs/core/langchain_core/language_models/_utils.py
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Sequence
 from typing import Optional
 
 from langchain_core.messages import BaseMessage
@@ -7,24 +8,30 @@ from langchain_core.messages import BaseMessage
 def _is_openai_data_block(block: dict) -> bool:
     """Check if the block contains multimodal data in OpenAI Chat Completions format."""
     if block.get("type") == "image_url":
-        url = block.get("image_url", {}).get("url")
-        if isinstance(url, str) and set(block.keys()) <= {
-            "type",
-            "image_url",
-            "detail",
-        }:
-            return True
+        if (
+            (set(block.keys()) <= {"type", "image_url", "detail"})
+            and (image_url := block.get("image_url"))
+            and isinstance(image_url, dict)
+        ):
+            url = image_url.get("url")
+            if isinstance(url, str):
+                return True
 
     elif block.get("type") == "file":
-        data = block.get("file", {}).get("file_data")
-        if isinstance(data, str):
-            return True
+        if (file := block.get("file")) and isinstance(file, dict):
+            file_data = file.get("file_data")
+            if isinstance(file_data, str):
+                return True
 
-    elif block.get("type") == "input_audio":
-        audio_data = block.get("input_audio", {}).get("data")
-        audio_format = block.get("input_audio", {}).get("format")
-        if isinstance(audio_data, str) and isinstance(audio_format, str):
-            return True
+    elif block.get("type") == "input_audio":  # noqa: SIM102
+        if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
+            audio_data = input_audio.get("data")
+            audio_format = input_audio.get("format")
+            if isinstance(audio_data, str) and isinstance(audio_format, str):
+                return True
+
+    else:
+        return False
 
     return False
@@ -98,7 +105,7 @@ def _convert_openai_format_to_data_block(block: dict) -> dict:
     return block
 
 
-def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
+def _normalize_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
    """Extend support for message formats.
 
     Chat models implement support for images in OpenAI Chat Completions format, as well
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 99c2a829b8d..00e7c8e9f26 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -13,6 +13,7 @@ from langchain_core.language_models import (
     FakeListChatModel,
     ParrotFakeChatModel,
 )
+from langchain_core.language_models._utils import _normalize_messages
 from langchain_core.language_models.fake_chat_models import FakeListChatModelError
 from langchain_core.messages import (
     AIMessage,
@@ -567,3 +568,31 @@ def test_extend_support_to_openai_multimodal_formats() -> None:
             "input_audio": {"data": "", "format": "wav"},
         },
     ]
+
+
+def test_normalize_messages_edge_cases() -> None:
+    # Test some blocks that should pass through
+    messages = [
+        HumanMessage(
+            content=[
+                {
+                    "type": "file",
+                    "file": "uri",
+                },
+                {
+                    "type": "input_file",
+                    "file_data": "uri",
+                    "filename": "file-name",
+                },
+                {
+                    "type": "input_audio",
+                    "input_audio": "uri",
+                },
+                {
+                    "type": "input_image",
+                    "image_url": "uri",
+                },
+            ]
+        )
+    ]
+    assert messages == _normalize_messages(messages)
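
Note (not part of the patch): a minimal sketch of the behavior change, assuming langchain_core with this patch applied is importable. _is_openai_data_block and _normalize_messages are the private helpers touched above. Before the added isinstance guards, a block whose "image_url", "file", or "input_audio" key held a non-dict value raised AttributeError inside chained .get() calls; now such blocks are simply not treated as OpenAI-format data and pass through untouched.

from langchain_core.language_models._utils import (
    _is_openai_data_block,
    _normalize_messages,
)
from langchain_core.messages import HumanMessage

# Well-formed OpenAI Chat Completions image block: still detected.
assert _is_openai_data_block(
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
)

# Malformed block: "file" holds a string, not a dict. The old code ran
# block.get("file", {}).get("file_data") and raised AttributeError here;
# the guarded version just reports "not OpenAI-format data".
assert not _is_openai_data_block({"type": "file", "file": "uri"})

# Unrecognized blocks now pass through _normalize_messages unchanged,
# which is what the added test_normalize_messages_edge_cases verifies.
messages = [HumanMessage(content=[{"type": "file", "file": "uri"}])]
assert _normalize_messages(messages) == messages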