mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 03:31:51 +00:00
core, standard-tests: support PDF and audio input in Chat Completions format (#30979)
Chat models currently implement support for: - images in OpenAI Chat Completions format - other multimodal types (e.g., PDF and audio) in a cross-provider [standard format](https://python.langchain.com/docs/how_to/multimodal_inputs/) Here we update core to extend support to PDF and audio input in Chat Completions format. **If an OAI-format PDF or audio content block is passed into any chat model, it will be transformed to the LangChain standard format**. We assume that any chat model supporting OAI-format PDF or audio has implemented support for the standard format.
This commit is contained in:
parent
d4fc734250
commit
faef3e5d50
132
libs/core/langchain_core/language_models/_utils.py
Normal file
132
libs/core/langchain_core/language_models/_utils.py
Normal file
@ -0,0 +1,132 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
def _is_openai_data_block(block: dict) -> bool:
|
||||
"""Check if the block contains multimodal data in OpenAI Chat Completions format."""
|
||||
if block.get("type") == "image_url":
|
||||
url = block.get("image_url", {}).get("url")
|
||||
if isinstance(url, str) and set(block.keys()) <= {
|
||||
"type",
|
||||
"image_url",
|
||||
"detail",
|
||||
}:
|
||||
return True
|
||||
|
||||
elif block.get("type") == "file":
|
||||
data = block.get("file", {}).get("file_data")
|
||||
if isinstance(data, str):
|
||||
return True
|
||||
|
||||
elif block.get("type") == "input_audio":
|
||||
audio_data = block.get("input_audio", {}).get("data")
|
||||
audio_format = block.get("input_audio", {}).get("format")
|
||||
if isinstance(audio_data, str) and isinstance(audio_format, str):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _parse_data_uri(uri: str) -> Optional[dict]:
|
||||
"""Parse a data URI into its components. If parsing fails, return None.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
data_uri = "..."
|
||||
parsed = _parse_data_uri(data_uri)
|
||||
|
||||
assert parsed == {
|
||||
"source_type": "base64",
|
||||
"mime_type": "image/jpeg",
|
||||
"data": "/9j/4AAQSkZJRg...",
|
||||
}
|
||||
"""
|
||||
regex = r"^data:(?P<mime_type>[^;]+);base64,(?P<data>.+)$"
|
||||
match = re.match(regex, uri)
|
||||
if match is None:
|
||||
return None
|
||||
return {
|
||||
"source_type": "base64",
|
||||
"data": match.group("data"),
|
||||
"mime_type": match.group("mime_type"),
|
||||
}
|
||||
|
||||
|
||||
def _convert_openai_format_to_data_block(block: dict) -> dict:
|
||||
"""Convert OpenAI image content block to standard data content block.
|
||||
|
||||
If parsing fails, pass-through.
|
||||
|
||||
Args:
|
||||
block: The OpenAI image content block to convert.
|
||||
|
||||
Returns:
|
||||
The converted standard data content block.
|
||||
"""
|
||||
if block["type"] == "image_url":
|
||||
parsed = _parse_data_uri(block["image_url"]["url"])
|
||||
if parsed is not None:
|
||||
parsed["type"] = "image"
|
||||
return parsed
|
||||
return block
|
||||
|
||||
if block["type"] == "file":
|
||||
parsed = _parse_data_uri(block["file"]["file_data"])
|
||||
if parsed is not None:
|
||||
parsed["type"] = "file"
|
||||
if filename := block["file"].get("filename"):
|
||||
parsed["filename"] = filename
|
||||
return parsed
|
||||
return block
|
||||
|
||||
if block["type"] == "input_audio":
|
||||
data = block["input_audio"].get("data")
|
||||
format = block["input_audio"].get("format")
|
||||
if data and format:
|
||||
return {
|
||||
"type": "audio",
|
||||
"source_type": "base64",
|
||||
"data": data,
|
||||
"mime_type": f"audio/{format}",
|
||||
}
|
||||
return block
|
||||
|
||||
return block
|
||||
|
||||
|
||||
def _normalize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Extend support for message formats.
|
||||
|
||||
Chat models implement support for images in OpenAI Chat Completions format, as well
|
||||
as other multimodal data as standard data blocks. This function extends support to
|
||||
audio and file data in OpenAI Chat Completions format by converting them to standard
|
||||
data blocks.
|
||||
"""
|
||||
formatted_messages = []
|
||||
for message in messages:
|
||||
formatted_message = message
|
||||
if isinstance(message.content, list):
|
||||
for idx, block in enumerate(message.content):
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
# Subset to (PDF) files and audio, as most relevant chat models
|
||||
# support images in OAI format (and some may not yet support the
|
||||
# standard data block format)
|
||||
and block.get("type") in ("file", "input_audio")
|
||||
and _is_openai_data_block(block)
|
||||
):
|
||||
if formatted_message is message:
|
||||
formatted_message = message.model_copy()
|
||||
# Also shallow-copy content
|
||||
formatted_message.content = list(formatted_message.content)
|
||||
|
||||
formatted_message.content[idx] = ( # type: ignore[index] # mypy confused by .model_copy
|
||||
_convert_openai_format_to_data_block(block)
|
||||
)
|
||||
formatted_messages.append(formatted_message)
|
||||
|
||||
return formatted_messages
|
@ -40,6 +40,7 @@ from langchain_core.callbacks import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.base import (
|
||||
BaseLanguageModel,
|
||||
LangSmithParams,
|
||||
@ -489,7 +490,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
self.rate_limiter.acquire(blocking=True)
|
||||
|
||||
try:
|
||||
for chunk in self._stream(messages, stop=stop, **kwargs):
|
||||
input_messages = _normalize_messages(messages)
|
||||
for chunk in self._stream(input_messages, stop=stop, **kwargs):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = f"run-{run_manager.run_id}"
|
||||
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
|
||||
@ -574,8 +576,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
try:
|
||||
input_messages = _normalize_messages(messages)
|
||||
async for chunk in self._astream(
|
||||
messages,
|
||||
input_messages,
|
||||
stop=stop,
|
||||
**kwargs,
|
||||
):
|
||||
@ -753,7 +756,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
batch_size=len(messages),
|
||||
)
|
||||
results = []
|
||||
for i, m in enumerate(messages):
|
||||
input_messages = [
|
||||
_normalize_messages(message_list) for message_list in messages
|
||||
]
|
||||
for i, m in enumerate(input_messages):
|
||||
try:
|
||||
results.append(
|
||||
self._generate_with_cache(
|
||||
@ -865,6 +871,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
input_messages = [
|
||||
_normalize_messages(message_list) for message_list in messages
|
||||
]
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
self._agenerate_with_cache(
|
||||
@ -873,7 +882,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_manager=run_managers[i] if run_managers else None,
|
||||
**kwargs,
|
||||
)
|
||||
for i, m in enumerate(messages)
|
||||
for i, m in enumerate(input_messages)
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
@ -455,3 +455,115 @@ def test_trace_images_in_openai_format() -> None:
|
||||
"url": "https://example.com/image.png",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_extend_support_to_openai_multimodal_formats() -> None:
|
||||
"""Test that chat models normalize OpenAI file and audio inputs."""
|
||||
llm = ParrotFakeChatModel()
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.png"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "..."},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"filename": "draconomicon.pdf",
|
||||
"file_data": "data:application/pdf;base64,<base64 string>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"file_data": "data:application/pdf;base64,<base64 string>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {"file_id": "<file id>"},
|
||||
},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": "<base64 data>", "format": "wav"},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
expected_content = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.png"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "..."},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 string>",
|
||||
"mime_type": "application/pdf",
|
||||
"filename": "draconomicon.pdf",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 string>",
|
||||
"mime_type": "application/pdf",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {"file_id": "<file id>"},
|
||||
},
|
||||
{
|
||||
"type": "audio",
|
||||
"source_type": "base64",
|
||||
"data": "<base64 data>",
|
||||
"mime_type": "audio/wav",
|
||||
},
|
||||
]
|
||||
response = llm.invoke(messages)
|
||||
assert response.content == expected_content
|
||||
|
||||
# Test no mutation
|
||||
assert messages[0]["content"] == [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.png"},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "..."},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"filename": "draconomicon.pdf",
|
||||
"file_data": "data:application/pdf;base64,<base64 string>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"file_data": "data:application/pdf;base64,<base64 string>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {"file_id": "<file id>"},
|
||||
},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": "<base64 data>", "format": "wav"},
|
||||
},
|
||||
]
|
||||
|
@ -103,6 +103,21 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
|
||||
)
|
||||
_ = model.invoke([message])
|
||||
|
||||
# Test OpenAI Chat Completions format
|
||||
message = HumanMessage(
|
||||
[
|
||||
{"type": "text", "text": "Summarize this document:"},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"filename": "test file.pdf",
|
||||
"file_data": f"data:application/pdf;base64,{pdf_data}",
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
_ = model.invoke([message])
|
||||
|
||||
|
||||
def _invoke(llm: ChatOpenAI, input_: str, stream: bool) -> AIMessage:
|
||||
if stream:
|
||||
|
@ -2036,6 +2036,24 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
)
|
||||
_ = model.invoke([message])
|
||||
|
||||
# Test OpenAI Chat Completions format
|
||||
message = HumanMessage(
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Summarize this document:",
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"file": {
|
||||
"filename": "test file.pdf",
|
||||
"file_data": f"data:application/pdf;base64,{pdf_data}",
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
_ = model.invoke([message])
|
||||
|
||||
def test_audio_inputs(self, model: BaseChatModel) -> None:
|
||||
"""Test that the model can process audio inputs.
|
||||
|
||||
@ -2093,6 +2111,21 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
)
|
||||
_ = model.invoke([message])
|
||||
|
||||
# Test OpenAI Chat Completions format
|
||||
message = HumanMessage(
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this audio:",
|
||||
},
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": audio_data, "format": "wav"},
|
||||
},
|
||||
]
|
||||
)
|
||||
_ = model.invoke([message])
|
||||
|
||||
def test_image_inputs(self, model: BaseChatModel) -> None:
|
||||
"""Test that the model can process image inputs.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user