From faef3e5d50ba5759df73e5303625f9fef10162cc Mon Sep 17 00:00:00 2001
From: ccurme
Date: Wed, 23 Apr 2025 14:32:51 -0400
Subject: [PATCH] core, standard-tests: support PDF and audio input in Chat Completions format (#30979)

Chat models currently implement support for:

- images in OpenAI Chat Completions format
- other multimodal types (e.g., PDF and audio) in a cross-provider
  [standard format](https://python.langchain.com/docs/how_to/multimodal_inputs/)

Here we update core to extend support to PDF and audio input in Chat
Completions format. **If an OAI-format PDF or audio content block is passed
into any chat model, it will be transformed to the LangChain standard
format**. We assume that any chat model supporting OAI-format PDF or audio
has implemented support for the standard format.
---
 .../langchain_core/language_models/_utils.py  | 132 ++++++++++++++++++
 .../language_models/chat_models.py            |  17 ++-
 .../language_models/chat_models/test_base.py  | 112 +++++++++++++++
 .../chat_models/test_base_standard.py         |  15 ++
 .../integration_tests/chat_models.py          |  33 +++++
 5 files changed, 305 insertions(+), 4 deletions(-)
 create mode 100644 libs/core/langchain_core/language_models/_utils.py

diff --git a/libs/core/langchain_core/language_models/_utils.py b/libs/core/langchain_core/language_models/_utils.py
new file mode 100644
index 00000000000..bc2285f9201
--- /dev/null
+++ b/libs/core/langchain_core/language_models/_utils.py
@@ -0,0 +1,132 @@
+import re
+from typing import Optional
+
+from langchain_core.messages import BaseMessage
+
+
+def _is_openai_data_block(block: dict) -> bool:
+    """Check if the block contains multimodal data in OpenAI Chat Completions format."""
+    if block.get("type") == "image_url":
+        url = block.get("image_url", {}).get("url")
+        if isinstance(url, str) and set(block.keys()) <= {
+            "type",
+            "image_url",
+            "detail",
+        }:
+            return True
+
+    elif block.get("type") == "file":
+        data = block.get("file", {}).get("file_data")
+        if isinstance(data, str):
+            return True
+
+    elif block.get("type") == "input_audio":
+        audio_data = block.get("input_audio", {}).get("data")
+        audio_format = block.get("input_audio", {}).get("format")
+        if isinstance(audio_data, str) and isinstance(audio_format, str):
+            return True
+
+    return False
+
+
+def _parse_data_uri(uri: str) -> Optional[dict]:
+    """Parse a data URI into its components. If parsing fails, return None.
+
+    Example:
+
+        .. code-block:: python
+
+            data_uri = "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
+            parsed = _parse_data_uri(data_uri)
+
+            assert parsed == {
+                "source_type": "base64",
+                "mime_type": "image/jpeg",
+                "data": "/9j/4AAQSkZJRg...",
+            }
+    """
+    regex = r"^data:(?P<mime_type>[^;]+);base64,(?P<data>.+)$"
+    match = re.match(regex, uri)
+    if match is None:
+        return None
+    return {
+        "source_type": "base64",
+        "data": match.group("data"),
+        "mime_type": match.group("mime_type"),
+    }
+
+
+def _convert_openai_format_to_data_block(block: dict) -> dict:
+    """Convert an OpenAI content block (image, file, or audio) to a standard data block.
+
+    If parsing fails, the block is passed through unchanged.
+
+    Args:
+        block: The OpenAI content block to convert.
+
+    Returns:
+        The converted standard data content block.
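+
+    Example (illustrative; mirrors the unit tests added in this patch, with
+    base64 payloads as placeholders):
+
+        .. code-block:: python
+
+            block = {
+                "type": "file",
+                "file": {
+                    "filename": "draconomicon.pdf",
+                    "file_data": "data:application/pdf;base64,<base64 string>",
+                },
+            }
+
+            assert _convert_openai_format_to_data_block(block) == {
+                "type": "file",
+                "source_type": "base64",
+                "mime_type": "application/pdf",
+                "data": "<base64 string>",
+                "filename": "draconomicon.pdf",
+            }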
+ """ + if block["type"] == "image_url": + parsed = _parse_data_uri(block["image_url"]["url"]) + if parsed is not None: + parsed["type"] = "image" + return parsed + return block + + if block["type"] == "file": + parsed = _parse_data_uri(block["file"]["file_data"]) + if parsed is not None: + parsed["type"] = "file" + if filename := block["file"].get("filename"): + parsed["filename"] = filename + return parsed + return block + + if block["type"] == "input_audio": + data = block["input_audio"].get("data") + format = block["input_audio"].get("format") + if data and format: + return { + "type": "audio", + "source_type": "base64", + "data": data, + "mime_type": f"audio/{format}", + } + return block + + return block + + +def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]: + """Extend support for message formats. + + Chat models implement support for images in OpenAI Chat Completions format, as well + as other multimodal data as standard data blocks. This function extends support to + audio and file data in OpenAI Chat Completions format by converting them to standard + data blocks. + """ + formatted_messages = [] + for message in messages: + formatted_message = message + if isinstance(message.content, list): + for idx, block in enumerate(message.content): + if ( + isinstance(block, dict) + # Subset to (PDF) files and audio, as most relevant chat models + # support images in OAI format (and some may not yet support the + # standard data block format) + and block.get("type") in ("file", "input_audio") + and _is_openai_data_block(block) + ): + if formatted_message is message: + formatted_message = message.model_copy() + # Also shallow-copy content + formatted_message.content = list(formatted_message.content) + + formatted_message.content[idx] = ( # type: ignore[index] # mypy confused by .model_copy + _convert_openai_format_to_data_block(block) + ) + formatted_messages.append(formatted_message) + + return formatted_messages diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 93fdc325898..f0493497c1f 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -40,6 +40,7 @@ from langchain_core.callbacks import ( Callbacks, ) from langchain_core.globals import get_llm_cache +from langchain_core.language_models._utils import _normalize_messages from langchain_core.language_models.base import ( BaseLanguageModel, LangSmithParams, @@ -489,7 +490,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): self.rate_limiter.acquire(blocking=True) try: - for chunk in self._stream(messages, stop=stop, **kwargs): + input_messages = _normalize_messages(messages) + for chunk in self._stream(input_messages, stop=stop, **kwargs): if chunk.message.id is None: chunk.message.id = f"run-{run_manager.run_id}" chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk) @@ -574,8 +576,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): generation: Optional[ChatGenerationChunk] = None try: + input_messages = _normalize_messages(messages) async for chunk in self._astream( - messages, + input_messages, stop=stop, **kwargs, ): @@ -753,7 +756,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): batch_size=len(messages), ) results = [] - for i, m in enumerate(messages): + input_messages = [ + _normalize_messages(message_list) for message_list in messages + ] + for i, m in enumerate(input_messages): try: results.append( 
+    """
+    formatted_messages = []
+    for message in messages:
+        formatted_message = message
+        if isinstance(message.content, list):
+            for idx, block in enumerate(message.content):
+                if (
+                    isinstance(block, dict)
+                    # Subset to (PDF) files and audio, as most relevant chat models
+                    # support images in OAI format (and some may not yet support the
+                    # standard data block format)
+                    and block.get("type") in ("file", "input_audio")
+                    and _is_openai_data_block(block)
+                ):
+                    if formatted_message is message:
+                        formatted_message = message.model_copy()
+                        # Also shallow-copy content
+                        formatted_message.content = list(formatted_message.content)
+
+                    formatted_message.content[idx] = (  # type: ignore[index]  # mypy confused by .model_copy
+                        _convert_openai_format_to_data_block(block)
+                    )
+        formatted_messages.append(formatted_message)
+
+    return formatted_messages
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 93fdc325898..f0493497c1f 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -40,6 +40,7 @@ from langchain_core.callbacks import (
     Callbacks,
 )
 from langchain_core.globals import get_llm_cache
+from langchain_core.language_models._utils import _normalize_messages
 from langchain_core.language_models.base import (
     BaseLanguageModel,
     LangSmithParams,
@@ -489,7 +490,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.rate_limiter.acquire(blocking=True)
 
         try:
-            for chunk in self._stream(messages, stop=stop, **kwargs):
+            input_messages = _normalize_messages(messages)
+            for chunk in self._stream(input_messages, stop=stop, **kwargs):
                 if chunk.message.id is None:
                     chunk.message.id = f"run-{run_manager.run_id}"
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
@@ -574,8 +576,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
 
         generation: Optional[ChatGenerationChunk] = None
         try:
+            input_messages = _normalize_messages(messages)
             async for chunk in self._astream(
-                messages,
+                input_messages,
                 stop=stop,
                 **kwargs,
             ):
@@ -753,7 +756,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             batch_size=len(messages),
         )
         results = []
-        for i, m in enumerate(messages):
+        input_messages = [
+            _normalize_messages(message_list) for message_list in messages
+        ]
+        for i, m in enumerate(input_messages):
             try:
                 results.append(
                     self._generate_with_cache(
@@ -865,6 +871,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 run_id=run_id,
             )
 
+        input_messages = [
+            _normalize_messages(message_list) for message_list in messages
+        ]
         results = await asyncio.gather(
             *[
                 self._agenerate_with_cache(
@@ -873,7 +882,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                     run_manager=run_managers[i] if run_managers else None,
                     **kwargs,
                 )
-                for i, m in enumerate(messages)
+                for i, m in enumerate(input_messages)
             ],
             return_exceptions=True,
         )
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index eb1a5960542..99c2a829b8d 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -455,3 +455,115 @@ def test_trace_images_in_openai_format() -> None:
             "url": "https://example.com/image.png",
         }
     ]
+
+
+def test_extend_support_to_openai_multimodal_formats() -> None:
+    """Test that chat models normalize OpenAI file and audio inputs."""
+    llm = ParrotFakeChatModel()
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Hello"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "https://example.com/image.png"},
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."},
+                },
+                {
+                    "type": "file",
+                    "file": {
+                        "filename": "draconomicon.pdf",
+                        "file_data": "data:application/pdf;base64,<base64 string>",
+                    },
+                },
+                {
+                    "type": "file",
+                    "file": {
+                        "file_data": "data:application/pdf;base64,<base64 string>",
+                    },
+                },
+                {
+                    "type": "file",
+                    "file": {"file_id": "<file id>"},
+                },
+                {
+                    "type": "input_audio",
+                    "input_audio": {"data": "<base64 data>", "format": "wav"},
+                },
+            ],
+        },
+    ]
+    expected_content = [
+        {"type": "text", "text": "Hello"},
+        {
+            "type": "image_url",
+            "image_url": {"url": "https://example.com/image.png"},
+        },
+        {
+            "type": "image_url",
+            "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."},
+        },
+        {
+            "type": "file",
+            "source_type": "base64",
+            "data": "<base64 string>",
+            "mime_type": "application/pdf",
+            "filename": "draconomicon.pdf",
+        },
+        {
+            "type": "file",
+            "source_type": "base64",
+            "data": "<base64 string>",
+            "mime_type": "application/pdf",
+        },
+        {
+            "type": "file",
+            "file": {"file_id": "<file id>"},
+        },
+        {
+            "type": "audio",
+            "source_type": "base64",
+            "data": "<base64 data>",
+            "mime_type": "audio/wav",
+        },
+    ]
+    response = llm.invoke(messages)
+    assert response.content == expected_content
+
+    # Test no mutation
+    assert messages[0]["content"] == [
+        {"type": "text", "text": "Hello"},
+        {
+            "type": "image_url",
+            "image_url": {"url": "https://example.com/image.png"},
+        },
+        {
+            "type": "image_url",
+            "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."},
+        },
+        {
+            "type": "file",
+            "file": {
+                "filename": "draconomicon.pdf",
+                "file_data": "data:application/pdf;base64,<base64 string>",
+            },
+        },
+        {
+            "type": "file",
+            "file": {
+                "file_data": "data:application/pdf;base64,<base64 string>",
+            },
+        },
+        {
+            "type": "file",
+            "file": {"file_id": "<file id>"},
+        },
+        {
+            "type": "input_audio",
+            "input_audio": {"data": "<base64 data>", "format": "wav"},
+        },
+    ]
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
index 14e8865d594..31f2cc70a92 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py
@@ -103,6 +103,21 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
         )
         _ = model.invoke([message])
 
+        # Test OpenAI Chat Completions format
+        message = HumanMessage(
+            [
+                {"type": "text", "text": "Summarize this document:"},
+                {
+                    "type": "file",
+                    "file": {
+                        "filename": "test file.pdf",
+                        "file_data": f"data:application/pdf;base64,{pdf_data}",
+                    },
+                },
+            ]
+        )
+        _ = model.invoke([message])
+
 
 def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
     if stream:
diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
index c9414e5f52a..ea6a2bb7068 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
@@ -2036,6 +2036,24 @@ class ChatModelIntegrationTests(ChatModelTests):
         )
         _ = model.invoke([message])
 
+        # Test OpenAI Chat Completions format
+        message = HumanMessage(
+            [
+                {
+                    "type": "text",
+                    "text": "Summarize this document:",
+                },
+                {
+                    "type": "file",
+                    "file": {
+                        "filename": "test file.pdf",
+                        "file_data": f"data:application/pdf;base64,{pdf_data}",
+                    },
+                },
+            ]
+        )
+        _ = model.invoke([message])
+
     def test_audio_inputs(self, model: BaseChatModel) -> None:
         """Test that the model can process audio inputs.
 
@@ -2093,6 +2111,21 @@ class ChatModelIntegrationTests(ChatModelTests):
         )
         _ = model.invoke([message])
 
+        # Test OpenAI Chat Completions format
+        message = HumanMessage(
+            [
+                {
+                    "type": "text",
+                    "text": "Describe this audio:",
+                },
+                {
+                    "type": "input_audio",
+                    "input_audio": {"data": audio_data, "format": "wav"},
+                },
+            ]
+        )
+        _ = model.invoke([message])
+
     def test_image_inputs(self, model: BaseChatModel) -> None:
         """Test that the model can process image inputs.