core, standard-tests: support PDF and audio input in Chat Completions format (#30979)

Chat models currently implement support for:
- images in OpenAI Chat Completions format
- other multimodal types (e.g., PDF and audio) in a cross-provider
[standard
format](https://python.langchain.com/docs/how_to/multimodal_inputs/)

Here we update core to extend support to PDF and audio input in Chat
Completions format. **If an OAI-format PDF or audio content block is
passed into any chat model, it will be transformed to the LangChain
standard format**. We assume that any chat model supporting OAI-format
PDF or audio has implemented support for the standard format.
This commit is contained in:
ccurme 2025-04-23 14:32:51 -04:00 committed by GitHub
parent d4fc734250
commit faef3e5d50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 305 additions and 4 deletions

View File

@ -0,0 +1,132 @@
import re
from typing import Optional
from langchain_core.messages import BaseMessage
def _is_openai_data_block(block: dict) -> bool:
"""Check if the block contains multimodal data in OpenAI Chat Completions format."""
if block.get("type") == "image_url":
url = block.get("image_url", {}).get("url")
if isinstance(url, str) and set(block.keys()) <= {
"type",
"image_url",
"detail",
}:
return True
elif block.get("type") == "file":
data = block.get("file", {}).get("file_data")
if isinstance(data, str):
return True
elif block.get("type") == "input_audio":
audio_data = block.get("input_audio", {}).get("data")
audio_format = block.get("input_audio", {}).get("format")
if isinstance(audio_data, str) and isinstance(audio_format, str):
return True
return False
def _parse_data_uri(uri: str) -> Optional[dict]:
"""Parse a data URI into its components. If parsing fails, return None.
Example:
.. code-block:: python
data_uri = "..."
parsed = _parse_data_uri(data_uri)
assert parsed == {
"source_type": "base64",
"mime_type": "image/jpeg",
"data": "/9j/4AAQSkZJRg...",
}
"""
regex = r"^data:(?P<mime_type>[^;]+);base64,(?P<data>.+)$"
match = re.match(regex, uri)
if match is None:
return None
return {
"source_type": "base64",
"data": match.group("data"),
"mime_type": match.group("mime_type"),
}
def _convert_openai_format_to_data_block(block: dict) -> dict:
"""Convert OpenAI image content block to standard data content block.
If parsing fails, pass-through.
Args:
block: The OpenAI image content block to convert.
Returns:
The converted standard data content block.
"""
if block["type"] == "image_url":
parsed = _parse_data_uri(block["image_url"]["url"])
if parsed is not None:
parsed["type"] = "image"
return parsed
return block
if block["type"] == "file":
parsed = _parse_data_uri(block["file"]["file_data"])
if parsed is not None:
parsed["type"] = "file"
if filename := block["file"].get("filename"):
parsed["filename"] = filename
return parsed
return block
if block["type"] == "input_audio":
data = block["input_audio"].get("data")
format = block["input_audio"].get("format")
if data and format:
return {
"type": "audio",
"source_type": "base64",
"data": data,
"mime_type": f"audio/{format}",
}
return block
return block
def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
"""Extend support for message formats.
Chat models implement support for images in OpenAI Chat Completions format, as well
as other multimodal data as standard data blocks. This function extends support to
audio and file data in OpenAI Chat Completions format by converting them to standard
data blocks.
"""
formatted_messages = []
for message in messages:
formatted_message = message
if isinstance(message.content, list):
for idx, block in enumerate(message.content):
if (
isinstance(block, dict)
# Subset to (PDF) files and audio, as most relevant chat models
# support images in OAI format (and some may not yet support the
# standard data block format)
and block.get("type") in ("file", "input_audio")
and _is_openai_data_block(block)
):
if formatted_message is message:
formatted_message = message.model_copy()
# Also shallow-copy content
formatted_message.content = list(formatted_message.content)
formatted_message.content[idx] = ( # type: ignore[index] # mypy confused by .model_copy
_convert_openai_format_to_data_block(block)
)
formatted_messages.append(formatted_message)
return formatted_messages

View File

@ -40,6 +40,7 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
@ -489,7 +490,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
self.rate_limiter.acquire(blocking=True)
try:
for chunk in self._stream(messages, stop=stop, **kwargs):
input_messages = _normalize_messages(messages)
for chunk in self._stream(input_messages, stop=stop, **kwargs):
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
@ -574,8 +576,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
generation: Optional[ChatGenerationChunk] = None
try:
input_messages = _normalize_messages(messages)
async for chunk in self._astream(
messages,
input_messages,
stop=stop,
**kwargs,
):
@ -753,7 +756,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=len(messages),
)
results = []
for i, m in enumerate(messages):
input_messages = [
_normalize_messages(message_list) for message_list in messages
]
for i, m in enumerate(input_messages):
try:
results.append(
self._generate_with_cache(
@ -865,6 +871,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_id=run_id,
)
input_messages = [
_normalize_messages(message_list) for message_list in messages
]
results = await asyncio.gather(
*[
self._agenerate_with_cache(
@ -873,7 +882,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
for i, m in enumerate(messages)
for i, m in enumerate(input_messages)
],
return_exceptions=True,
)

View File

@ -455,3 +455,115 @@ def test_trace_images_in_openai_format() -> None:
"url": "https://example.com/image.png",
}
]
def test_extend_support_to_openai_multimodal_formats() -> None:
"""Test that chat models normalize OpenAI file and audio inputs."""
llm = ParrotFakeChatModel()
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "..."},
},
{
"type": "file",
"file": {
"filename": "draconomicon.pdf",
"file_data": "data:application/pdf;base64,<base64 string>",
},
},
{
"type": "file",
"file": {
"file_data": "data:application/pdf;base64,<base64 string>",
},
},
{
"type": "file",
"file": {"file_id": "<file id>"},
},
{
"type": "input_audio",
"input_audio": {"data": "<base64 data>", "format": "wav"},
},
],
},
]
expected_content = [
{"type": "text", "text": "Hello"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "..."},
},
{
"type": "file",
"source_type": "base64",
"data": "<base64 string>",
"mime_type": "application/pdf",
"filename": "draconomicon.pdf",
},
{
"type": "file",
"source_type": "base64",
"data": "<base64 string>",
"mime_type": "application/pdf",
},
{
"type": "file",
"file": {"file_id": "<file id>"},
},
{
"type": "audio",
"source_type": "base64",
"data": "<base64 data>",
"mime_type": "audio/wav",
},
]
response = llm.invoke(messages)
assert response.content == expected_content
# Test no mutation
assert messages[0]["content"] == [
{"type": "text", "text": "Hello"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "..."},
},
{
"type": "file",
"file": {
"filename": "draconomicon.pdf",
"file_data": "data:application/pdf;base64,<base64 string>",
},
},
{
"type": "file",
"file": {
"file_data": "data:application/pdf;base64,<base64 string>",
},
},
{
"type": "file",
"file": {"file_id": "<file id>"},
},
{
"type": "input_audio",
"input_audio": {"data": "<base64 data>", "format": "wav"},
},
]

View File

@ -103,6 +103,21 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
)
_ = model.invoke([message])
# Test OpenAI Chat Completions format
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
{
"type": "file",
"file": {
"filename": "test file.pdf",
"file_data": f"data:application/pdf;base64,{pdf_data}",
},
},
]
)
_ = model.invoke([message])
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
if stream:

View File

@ -2036,6 +2036,24 @@ class ChatModelIntegrationTests(ChatModelTests):
)
_ = model.invoke([message])
# Test OpenAI Chat Completions format
message = HumanMessage(
[
{
"type": "text",
"text": "Summarize this document:",
},
{
"type": "file",
"file": {
"filename": "test file.pdf",
"file_data": f"data:application/pdf;base64,{pdf_data}",
},
},
]
)
_ = model.invoke([message])
def test_audio_inputs(self, model: BaseChatModel) -> None:
"""Test that the model can process audio inputs.
@ -2093,6 +2111,21 @@ class ChatModelIntegrationTests(ChatModelTests):
)
_ = model.invoke([message])
# Test OpenAI Chat Completions format
message = HumanMessage(
[
{
"type": "text",
"text": "Describe this audio:",
},
{
"type": "input_audio",
"input_audio": {"data": audio_data, "format": "wav"},
},
]
)
_ = model.invoke([message])
def test_image_inputs(self, model: BaseChatModel) -> None:
"""Test that the model can process image inputs.