diff --git a/libs/partners/mistralai/langchain_mistralai/_compat.py b/libs/partners/mistralai/langchain_mistralai/_compat.py new file mode 100644 index 00000000000..f716300dccc --- /dev/null +++ b/libs/partners/mistralai/langchain_mistralai/_compat.py @@ -0,0 +1,125 @@ +"""Derivations of standard content blocks from mistral content.""" + +from __future__ import annotations + +from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.messages import content as types +from langchain_core.messages.block_translators import register_translator + + +def _convert_from_v1_to_mistral( + content: list[types.ContentBlock], + model_provider: str | None, +) -> str | list[str | dict]: + new_content: list = [] + for block in content: + if block["type"] == "text": + new_content.append({"text": block.get("text", ""), "type": "text"}) + + elif ( + block["type"] == "reasoning" + and (reasoning := block.get("reasoning")) + and isinstance(reasoning, str) + and model_provider == "mistralai" + ): + new_content.append( + { + "type": "thinking", + "thinking": [{"type": "text", "text": reasoning}], + } + ) + + elif ( + block["type"] == "non_standard" + and "value" in block + and model_provider == "mistralai" + ): + new_content.append(block["value"]) + elif block["type"] == "tool_call": + continue + else: + new_content.append(block) + + return new_content + + +def _convert_to_v1_from_mistral(message: AIMessage) -> list[types.ContentBlock]: + """Convert mistral message content to v1 format.""" + if isinstance(message.content, str): + content_blocks: list[types.ContentBlock] = [ + {"type": "text", "text": message.content} + ] + + else: + content_blocks = [] + for block in message.content: + if isinstance(block, str): + content_blocks.append({"type": "text", "text": block}) + + elif isinstance(block, dict): + if block.get("type") == "text" and isinstance(block.get("text"), str): + text_block: types.TextContentBlock = { + "type": "text", + "text": block["text"], + } + if "index" in block: + text_block["index"] = block["index"] + content_blocks.append(text_block) + + elif block.get("type") == "thinking" and isinstance( + block.get("thinking"), list + ): + for sub_block in block["thinking"]: + if ( + isinstance(sub_block, dict) + and sub_block.get("type") == "text" + ): + reasoning_block: types.ReasoningContentBlock = { + "type": "reasoning", + "reasoning": sub_block.get("text", ""), + } + if "index" in block: + reasoning_block["index"] = block["index"] + content_blocks.append(reasoning_block) + + else: + non_standard_block: types.NonStandardContentBlock = { + "type": "non_standard", + "value": block, + } + content_blocks.append(non_standard_block) + else: + continue + + if ( + len(content_blocks) == 1 + and content_blocks[0].get("type") == "text" + and content_blocks[0].get("text") == "" + and message.tool_calls + ): + content_blocks = [] + + for tool_call in message.tool_calls: + content_blocks.append( + { + "type": "tool_call", + "name": tool_call["name"], + "args": tool_call["args"], + "id": tool_call.get("id"), + } + ) + + return content_blocks + + +def translate_content(message: AIMessage) -> list[types.ContentBlock]: + """Derive standard content blocks from a message with mistral content.""" + return _convert_to_v1_from_mistral(message) + + +def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]: + """Derive standard content blocks from a message chunk with mistral content.""" + return _convert_to_v1_from_mistral(message) + + +register_translator("mistralai", translate_content, translate_content_chunk) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 578f0acdaa5..c9ca6c02ca5 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -24,12 +24,7 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import ( - BaseChatModel, - LangSmithParams, - agenerate_from_stream, - generate_from_stream, -) +from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams from langchain_core.language_models.llms import create_base_retry_decorator from langchain_core.messages import ( AIMessage, @@ -74,6 +69,8 @@ from pydantic import ( ) from typing_extensions import Self +from langchain_mistralai._compat import _convert_from_v1_to_mistral + if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import AbstractAsyncContextManager @@ -160,6 +157,7 @@ def _convert_mistral_chat_message_to_message( additional_kwargs=additional_kwargs, tool_calls=tool_calls, invalid_tool_calls=invalid_tool_calls, + response_metadata={"model_provider": "mistralai"}, ) @@ -231,14 +229,34 @@ async def acompletion_with_retry( def _convert_chunk_to_message_chunk( - chunk: dict, default_class: type[BaseMessageChunk] -) -> BaseMessageChunk: + chunk: dict, + default_class: type[BaseMessageChunk], + index: int, + index_type: str, + output_version: str | None, +) -> tuple[BaseMessageChunk, int, str]: _choice = chunk["choices"][0] _delta = _choice["delta"] role = _delta.get("role") content = _delta.get("content") or "" + if output_version == "v1" and isinstance(content, str): + content = [{"type": "text", "text": content}] + if isinstance(content, list): + for block in content: + if isinstance(block, dict): + if "type" in block and block["type"] != index_type: + index_type = block["type"] + index = index + 1 + if "index" not in block: + block["index"] = index + if block.get("type") == "thinking" and isinstance( + block.get("thinking"), list + ): + for sub_block in block["thinking"]: + if isinstance(sub_block, dict) and "index" not in sub_block: + sub_block["index"] = 0 if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) + return HumanMessageChunk(content=content), index, index_type if role == "assistant" or default_class == AIMessageChunk: additional_kwargs: dict = {} response_metadata = {} @@ -276,18 +294,22 @@ def _convert_chunk_to_message_chunk( ): response_metadata["model_name"] = chunk["model"] response_metadata["finish_reason"] = _choice["finish_reason"] - return AIMessageChunk( - content=content, - additional_kwargs=additional_kwargs, - tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] - usage_metadata=usage_metadata, # type: ignore[arg-type] - response_metadata=response_metadata, + return ( + AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + usage_metadata=usage_metadata, # type: ignore[arg-type] + response_metadata={"model_provider": "mistralai", **response_metadata}, + ), + index, + index_type, ) if role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) + return SystemMessageChunk(content=content), index, index_type if role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) - return default_class(content=content) # type: ignore[call-arg] + return ChatMessageChunk(content=content, role=role), index, index_type + return default_class(content=content), index, index_type # type: ignore[call-arg] def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict: @@ -318,6 +340,21 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> return result +def _clean_block(block: dict) -> dict: + # Remove "index" key added for message aggregation in langchain-core + new_block = {k: v for k, v in block.items() if k != "index"} + if block.get("type") == "thinking" and isinstance(block.get("thinking"), list): + new_block["thinking"] = [ + ( + {k: v for k, v in sb.items() if k != "index"} + if isinstance(sb, dict) and "index" in sb + else sb + ) + for sb in block["thinking"] + ] + return new_block + + def _convert_message_to_mistral_chat_message( message: BaseMessage, ) -> dict: @@ -356,13 +393,40 @@ def _convert_message_to_mistral_chat_message( pass if tool_calls: # do not populate empty list tool_calls message_dict["tool_calls"] = tool_calls - if tool_calls and message.content: + + # Message content + # Translate v1 content + if message.response_metadata.get("output_version") == "v1": + content = _convert_from_v1_to_mistral( + message.content_blocks, message.response_metadata.get("model_provider") + ) + else: + content = message.content + + if tool_calls and content: # Assistant message must have either content or tool_calls, but not both. # Some providers may not support tool_calls in the same message as content. # This is done to ensure compatibility with messages from other providers. - message_dict["content"] = "" + content = "" + + elif isinstance(content, list): + content = [ + _clean_block(block) + if isinstance(block, dict) and "index" in block + else block + for block in content + ] else: - message_dict["content"] = message.content + content = message.content + + # if any blocks are dicts, cast strings to text blocks + if any(isinstance(block, dict) for block in content): + content = [ + block if isinstance(block, dict) else {"type": "text", "text": block} + for block in content + ] + message_dict["content"] = content + if "prefix" in message.additional_kwargs: message_dict["prefix"] = message.additional_kwargs["prefix"] return message_dict @@ -564,13 +628,6 @@ class ChatMistralAI(BaseChatModel): stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} response = self.completion_with_retry( @@ -627,12 +684,16 @@ class ChatMistralAI(BaseChatModel): params = {**params, **kwargs, "stream": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk + index = -1 + index_type = "" for chunk in self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ): if len(chunk.get("choices", [])) == 0: continue - new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) + new_chunk, index, index_type = _convert_chunk_to_message_chunk( + chunk, default_chunk_class, index, index_type, self.output_version + ) # make future chunks same type as first chunk default_chunk_class = new_chunk.__class__ gen_chunk = ChatGenerationChunk(message=new_chunk) @@ -653,12 +714,16 @@ class ChatMistralAI(BaseChatModel): params = {**params, **kwargs, "stream": True} default_chunk_class: type[BaseMessageChunk] = AIMessageChunk + index = -1 + index_type = "" async for chunk in await acompletion_with_retry( self, messages=message_dicts, run_manager=run_manager, **params ): if len(chunk.get("choices", [])) == 0: continue - new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class) + new_chunk, index, index_type = _convert_chunk_to_message_chunk( + chunk, default_chunk_class, index, index_type, self.output_version + ) # make future chunks same type as first chunk default_chunk_class = new_chunk.__class__ gen_chunk = ChatGenerationChunk(message=new_chunk) @@ -676,13 +741,6 @@ class ChatMistralAI(BaseChatModel): stream: bool | None = None, # noqa: FBT001 **kwargs: Any, ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._astream( - messages=messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await agenerate_from_stream(stream_iter) - message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} response = await acompletion_with_retry( diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index e4a956e269e..bef96ae3e14 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -28,7 +28,9 @@ async def test_astream() -> None: full = token if full is None else full + token if token.usage_metadata is not None: chunks_with_token_counts += 1 - if token.response_metadata: + if token.response_metadata and not set(token.response_metadata.keys()).issubset( + {"model_provider", "output_version"} + ): chunks_with_response_metadata += 1 if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1: msg = ( @@ -143,3 +145,51 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None: except Exception: logger.exception("Unexpected exception") raise + + +def test_reasoning() -> None: + model = ChatMistralAI(model="magistral-medium-latest") # type: ignore[call-arg] + input_message = { + "role": "user", + "content": "Hello, my name is Bob.", + } + full: AIMessageChunk | None = None + for chunk in model.stream([input_message]): + assert isinstance(chunk, AIMessageChunk) + full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + thinking_blocks = 0 + for i, block in enumerate(full.content): + if isinstance(block, dict) and block.get("type") == "thinking": + thinking_blocks += 1 + reasoning_block = full.content_blocks[i] + assert reasoning_block["type"] == "reasoning" + assert isinstance(reasoning_block.get("reasoning"), str) + assert thinking_blocks > 0 + + next_message = {"role": "user", "content": "What is my name?"} + _ = model.invoke([input_message, full, next_message]) + + +def test_reasoning_v1() -> None: + model = ChatMistralAI(model="magistral-medium-latest", output_version="v1") # type: ignore[call-arg] + input_message = { + "role": "user", + "content": "Hello, my name is Bob.", + } + full: AIMessageChunk | None = None + chunks = [] + for chunk in model.stream([input_message]): + assert isinstance(chunk, AIMessageChunk) + full = chunk if full is None else full + chunk + chunks.append(chunk) + assert isinstance(full, AIMessageChunk) + reasoning_blocks = 0 + for block in full.content: + if isinstance(block, dict) and block.get("type") == "reasoning": + reasoning_blocks += 1 + assert isinstance(block.get("reasoning"), str) + assert reasoning_blocks > 0 + + next_message = {"role": "user", "content": "What is my name?"} + _ = model.invoke([input_message, full, next_message]) diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 7851275ed94..e582d12156a 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -188,6 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None: type="tool_call", ) ], + response_metadata={"model_provider": "mistralai"}, ) assert result == expected_output assert _convert_message_to_mistral_chat_message(expected_output) == message @@ -231,6 +232,7 @@ def test__convert_dict_to_message_tool_call() -> None: type="tool_call", ), ], + response_metadata={"model_provider": "mistralai"}, ) assert result == expected_output assert _convert_message_to_mistral_chat_message(expected_output) == message