diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 809c1a6ae6f..e920fc6784c 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -373,6 +373,25 @@ def _clean_block(block: dict) -> dict: return new_block +def _sanitize_chat_completions_content(content: Any) -> Any: + """Strip non-wire keys from text content blocks. + + Mistral's chat completions endpoint rejects unknown fields on tool + message content blocks (e.g. the `id` that LangChain auto-generates on + `TextContentBlock`). For list content, keep only `type` and `text` on + text blocks; pass other blocks and non-list content through unchanged. + """ + if not isinstance(content, list): + return content + sanitized: list[Any] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text" and "text" in block: + sanitized.append({"type": "text", "text": block["text"]}) + else: + sanitized.append(block) + return sanitized + + def _format_message_content(content: Any) -> Any: """Format message content for the Mistral chat completions wire format. @@ -484,7 +503,7 @@ def _convert_message_to_mistral_chat_message( if isinstance(message, ToolMessage): return { "role": "tool", - "content": message.content, + "content": _sanitize_chat_completions_content(message.content), "name": message.name, "tool_call_id": _convert_tool_call_id_to_mistral_compatible( message.tool_call_id diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index b4cc3fd1531..90be8938812 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -16,6 +16,7 @@ from langchain_core.messages import ( InvalidToolCall, SystemMessage, ToolCall, + ToolMessage, ) from pydantic import SecretStr @@ -26,11 +27,30 @@ from langchain_mistralai.chat_models import ( # type: ignore[import] _convert_tool_call_id_to_mistral_compatible, _format_message_content, _is_valid_mistral_tool_call_id, + _sanitize_chat_completions_content, ) os.environ["MISTRAL_API_KEY"] = "foo" +def test_sanitize_chat_completions_text_blocks_strips_id() -> None: + """LangChain auto-generated `id` on text blocks must not reach the wire. + + Mistral's chat completions endpoint returns 422 with `extra_forbidden` + on `messages[*].tool.content.list[...].text.id` if not stripped. + """ + message = ToolMessage( + content=[{"type": "text", "text": "foo", "id": "lc_abc123"}], + tool_call_id="abc12345", + ) + result = _convert_message_to_mistral_chat_message(message) + assert result["content"] == [{"type": "text", "text": "foo"}] + + +def test_sanitize_chat_completions_content_passthrough_string() -> None: + assert _sanitize_chat_completions_content("hello") == "hello" + + def test_mistralai_model_param() -> None: llm = ChatMistralAI(model="foo") # type: ignore[call-arg] assert llm.model == "foo"