diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index fb9bce64c37..809c1a6ae6f 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -44,6 +44,10 @@ from langchain_core.messages import ( SystemMessageChunk, ToolCall, ToolMessage, + is_data_content_block, +) +from langchain_core.messages.block_translators.openai import ( + convert_to_openai_data_block, ) from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import ( @@ -369,13 +373,46 @@ def _clean_block(block: dict) -> dict: return new_block +def _format_message_content(content: Any) -> Any: + """Format message content for the Mistral chat completions wire format. + + Walks list content and translates LangChain canonical v0/v1 multimodal + data blocks (e.g. `ImageContentBlock` with `url`, `base64`, or + `file_id`) into the OpenAI-compatible shape that Mistral accepts: + `{"type": "image_url", "image_url": {"url": "..."}}`. Strings and any + other dict blocks are returned unchanged so that already-translated wire + blocks (e.g. `text`, `image_url`) and Mistral-specific blocks + (`document_url`, `input_audio`) pass through; the API surfaces an error + for anything it doesn't understand. + + Args: + content: The message content. Strings and non-list values pass + through unchanged; lists are walked block by block. + + Returns: + The formatted content. List inputs return a new list with canonical + data-block translations applied; other inputs are returned as-is. + """ + if not isinstance(content, list): + return content + formatted: list[Any] = [] + for block in content: + if isinstance(block, dict) and is_data_content_block(block): + formatted.append( + convert_to_openai_data_block(block, api="chat/completions") + ) + continue + formatted.append(block) + return formatted + + def _convert_message_to_mistral_chat_message( message: BaseMessage, ) -> dict: if isinstance(message, ChatMessage): return {"role": message.role, "content": message.content} if isinstance(message, HumanMessage): - return {"role": "user", "content": message.content} + return {"role": "user", "content": _format_message_content(message.content)} if isinstance(message, AIMessage): message_dict: dict[str, Any] = {"role": "assistant"} tool_calls: list = [] diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 76a8dfac39e..b4cc3fd1531 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -24,6 +24,7 @@ from langchain_mistralai.chat_models import ( # type: ignore[import] _convert_message_to_mistral_chat_message, _convert_mistral_chat_message_to_message, _convert_tool_call_id_to_mistral_compatible, + _format_message_content, _is_valid_mistral_tool_call_id, ) @@ -111,6 +112,180 @@ def test_convert_message_to_mistral_chat_message( assert result == expected +@pytest.mark.parametrize( + ("content", "expected"), + [ + ("hello", "hello"), + ("", ""), + (None, None), + ([], []), + ], +) +def test_format_message_content_passthrough_non_list( + content: Any, expected: Any +) -> None: + """Strings, None, and empty lists pass through `_format_message_content`.""" + assert _format_message_content(content) == expected + + +@pytest.mark.parametrize( + ("block", "expected"), + [ + ( + {"type": "image", "url": "https://example.com/img.png"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.png"}, + }, + ), + ( + {"type": "image", "base64": "abc123", "mime_type": "image/jpeg"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,abc123"}, + }, + ), + ( + { + "type": "image", + "source_type": "url", + "url": "https://example.com/v0.png", + }, + { + "type": "image_url", + "image_url": {"url": "https://example.com/v0.png"}, + }, + ), + ( + { + "type": "image", + "source_type": "base64", + "data": "v0data", + "mime_type": "image/png", + }, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,v0data"}, + }, + ), + ], +) +def test_format_message_content_translates_image_blocks( + block: dict, expected: dict +) -> None: + """v0 and v1 canonical image blocks translate to Mistral's `image_url` shape.""" + assert _format_message_content([block]) == [expected] + + +@pytest.mark.parametrize( + "block", + [ + {"type": "text", "text": "hello"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, + {"type": "image_url", "image_url": "https://example.com/img.png"}, + ], +) +def test_format_message_content_passthrough_known_blocks(block: dict) -> None: + """Already-translated wire blocks and text blocks pass through unchanged.""" + assert _format_message_content([block]) == [block] + + +@pytest.mark.parametrize( + "block_type", + ["tool_use", "thinking", "reasoning_content", "document_url", "input_audio"], +) +def test_format_message_content_passes_unknown_blocks_through(block_type: str) -> None: + """Non-canonical blocks pass through; the Mistral API validates them.""" + blocks = [ + {"type": "text", "text": "kept"}, + {"type": block_type, "data": "anything"}, + ] + assert _format_message_content(blocks) == blocks + + +def test_format_message_content_preserves_order_for_mixed_blocks() -> None: + """Multiple text + image blocks retain their order — vision prompts depend on it.""" + blocks: list[Any] = [ + {"type": "text", "text": "first"}, + {"type": "image", "url": "https://example.com/a.png"}, + {"type": "text", "text": "between"}, + {"type": "image", "base64": "xyz", "mime_type": "image/png"}, + "trailing string", + ] + expected = [ + {"type": "text", "text": "first"}, + {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}, + {"type": "text", "text": "between"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,xyz"}}, + "trailing string", + ] + assert _format_message_content(blocks) == expected + + +def test_format_message_content_image_missing_mime_type_raises() -> None: + """Base64 image without `mime_type` raises via the core translator.""" + with pytest.raises(ValueError, match="mime_type"): + _format_message_content([{"type": "image", "base64": "abc"}]) + + +@pytest.mark.parametrize( + ("message", "expected"), + [ + ( + HumanMessage( + content=[ + {"type": "text", "text": "What is in this image?"}, + {"type": "image", "url": "https://example.com/img.png"}, + ] + ), + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.png"}, + }, + ], + }, + ), + ( + HumanMessage( + content=[ + {"type": "text", "text": "Describe this image."}, + { + "type": "image", + "base64": "abc123", + "mime_type": "image/png", + }, + ] + ), + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image."}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,abc123"}, + }, + ], + }, + ), + ], +) +def test_convert_human_message_with_images( + message: BaseMessage, expected: dict +) -> None: + result = _convert_message_to_mistral_chat_message(message) + assert result == expected + + +def test_convert_human_message_with_string_content_unchanged() -> None: + """Plain string `HumanMessage` content is not wrapped or modified.""" + result = _convert_message_to_mistral_chat_message(HumanMessage(content="hi")) + assert result == {"role": "user", "content": "hi"} + + def _make_completion_response_from_token(token: str) -> dict: return { "id": "abc123",