feat(mistralai): support reasoning feature and v1 content (#33485)

Not yet supported: server-side tool calls
This commit is contained in:
ccurme
2025-10-14 15:19:44 -04:00
committed by GitHub
parent 99e0a60aab
commit 9f4366bc9d
4 changed files with 273 additions and 38 deletions

View File

@@ -0,0 +1,125 @@
"""Derivations of standard content blocks from mistral content."""
from __future__ import annotations
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages import content as types
from langchain_core.messages.block_translators import register_translator
def _convert_from_v1_to_mistral(
content: list[types.ContentBlock],
model_provider: str | None,
) -> str | list[str | dict]:
new_content: list = []
for block in content:
if block["type"] == "text":
new_content.append({"text": block.get("text", ""), "type": "text"})
elif (
block["type"] == "reasoning"
and (reasoning := block.get("reasoning"))
and isinstance(reasoning, str)
and model_provider == "mistralai"
):
new_content.append(
{
"type": "thinking",
"thinking": [{"type": "text", "text": reasoning}],
}
)
elif (
block["type"] == "non_standard"
and "value" in block
and model_provider == "mistralai"
):
new_content.append(block["value"])
elif block["type"] == "tool_call":
continue
else:
new_content.append(block)
return new_content
def _convert_to_v1_from_mistral(message: AIMessage) -> list[types.ContentBlock]:
"""Convert mistral message content to v1 format."""
if isinstance(message.content, str):
content_blocks: list[types.ContentBlock] = [
{"type": "text", "text": message.content}
]
else:
content_blocks = []
for block in message.content:
if isinstance(block, str):
content_blocks.append({"type": "text", "text": block})
elif isinstance(block, dict):
if block.get("type") == "text" and isinstance(block.get("text"), str):
text_block: types.TextContentBlock = {
"type": "text",
"text": block["text"],
}
if "index" in block:
text_block["index"] = block["index"]
content_blocks.append(text_block)
elif block.get("type") == "thinking" and isinstance(
block.get("thinking"), list
):
for sub_block in block["thinking"]:
if (
isinstance(sub_block, dict)
and sub_block.get("type") == "text"
):
reasoning_block: types.ReasoningContentBlock = {
"type": "reasoning",
"reasoning": sub_block.get("text", ""),
}
if "index" in block:
reasoning_block["index"] = block["index"]
content_blocks.append(reasoning_block)
else:
non_standard_block: types.NonStandardContentBlock = {
"type": "non_standard",
"value": block,
}
content_blocks.append(non_standard_block)
else:
continue
if (
len(content_blocks) == 1
and content_blocks[0].get("type") == "text"
and content_blocks[0].get("text") == ""
and message.tool_calls
):
content_blocks = []
for tool_call in message.tool_calls:
content_blocks.append(
{
"type": "tool_call",
"name": tool_call["name"],
"args": tool_call["args"],
"id": tool_call.get("id"),
}
)
return content_blocks
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
    """Derive standard content blocks from a message with mistral content.

    Thin adapter used as the non-streaming translator callback registered
    with langchain-core for the ``mistralai`` provider.
    """
    return _convert_to_v1_from_mistral(message)
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
    """Derive standard content blocks from a message chunk with mistral content.

    Chunks share the same wire shape as full messages here, so this delegates
    to the same converter as `translate_content`.
    """
    return _convert_to_v1_from_mistral(message)
# Register the converters so langchain-core can derive `content_blocks` for
# messages whose provider is tagged "mistralai".
register_translator("mistralai", translate_content, translate_content_chunk)

View File

@@ -24,12 +24,7 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
@@ -74,6 +69,8 @@ from pydantic import (
)
from typing_extensions import Self
from langchain_mistralai._compat import _convert_from_v1_to_mistral
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
@@ -160,6 +157,7 @@ def _convert_mistral_chat_message_to_message(
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata={"model_provider": "mistralai"},
)
@@ -231,14 +229,34 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk(
chunk: dict, default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
chunk: dict,
default_class: type[BaseMessageChunk],
index: int,
index_type: str,
output_version: str | None,
) -> tuple[BaseMessageChunk, int, str]:
_choice = chunk["choices"][0]
_delta = _choice["delta"]
role = _delta.get("role")
content = _delta.get("content") or ""
if output_version == "v1" and isinstance(content, str):
content = [{"type": "text", "text": content}]
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
if "type" in block and block["type"] != index_type:
index_type = block["type"]
index = index + 1
if "index" not in block:
block["index"] = index
if block.get("type") == "thinking" and isinstance(
block.get("thinking"), list
):
for sub_block in block["thinking"]:
if isinstance(sub_block, dict) and "index" not in sub_block:
sub_block["index"] = 0
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
return HumanMessageChunk(content=content), index, index_type
if role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: dict = {}
response_metadata = {}
@@ -276,18 +294,22 @@ def _convert_chunk_to_message_chunk(
):
response_metadata["model_name"] = chunk["model"]
response_metadata["finish_reason"] = _choice["finish_reason"]
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata=response_metadata,
return (
AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "mistralai", **response_metadata},
),
index,
index_type,
)
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
return SystemMessageChunk(content=content), index, index_type
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
return default_class(content=content) # type: ignore[call-arg]
return ChatMessageChunk(content=content, role=role), index, index_type
return default_class(content=content), index, index_type # type: ignore[call-arg]
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
@@ -318,6 +340,21 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
return result
def _clean_block(block: dict) -> dict:
# Remove "index" key added for message aggregation in langchain-core
new_block = {k: v for k, v in block.items() if k != "index"}
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
new_block["thinking"] = [
(
{k: v for k, v in sb.items() if k != "index"}
if isinstance(sb, dict) and "index" in sb
else sb
)
for sb in block["thinking"]
]
return new_block
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> dict:
@@ -356,13 +393,40 @@ def _convert_message_to_mistral_chat_message(
pass
if tool_calls: # do not populate empty list tool_calls
message_dict["tool_calls"] = tool_calls
if tool_calls and message.content:
# Message content
# Translate v1 content
if message.response_metadata.get("output_version") == "v1":
content = _convert_from_v1_to_mistral(
message.content_blocks, message.response_metadata.get("model_provider")
)
else:
content = message.content
if tool_calls and content:
# Assistant message must have either content or tool_calls, but not both.
# Some providers may not support tool_calls in the same message as content.
# This is done to ensure compatibility with messages from other providers.
message_dict["content"] = ""
content = ""
elif isinstance(content, list):
content = [
_clean_block(block)
if isinstance(block, dict) and "index" in block
else block
for block in content
]
else:
message_dict["content"] = message.content
content = message.content
# if any blocks are dicts, cast strings to text blocks
if any(isinstance(block, dict) for block in content):
content = [
block if isinstance(block, dict) else {"type": "text", "text": block}
for block in content
]
message_dict["content"] = content
if "prefix" in message.additional_kwargs:
message_dict["prefix"] = message.additional_kwargs["prefix"]
return message_dict
@@ -564,13 +628,6 @@ class ChatMistralAI(BaseChatModel):
stream: bool | None = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
@@ -627,12 +684,16 @@ class ChatMistralAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
index = -1
index_type = ""
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.get("choices", [])) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
chunk, default_chunk_class, index, index_type, self.output_version
)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -653,12 +714,16 @@ class ChatMistralAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
index = -1
index_type = ""
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.get("choices", [])) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
chunk, default_chunk_class, index, index_type, self.output_version
)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -676,13 +741,6 @@ class ChatMistralAI(BaseChatModel):
stream: bool | None = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(

View File

@@ -28,7 +28,9 @@ async def test_astream() -> None:
full = token if full is None else full + token
if token.usage_metadata is not None:
chunks_with_token_counts += 1
if token.response_metadata:
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
{"model_provider", "output_version"}
):
chunks_with_response_metadata += 1
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
msg = (
@@ -143,3 +145,51 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
except Exception:
logger.exception("Unexpected exception")
raise
def test_reasoning() -> None:
    """Streamed magistral output should contain native "thinking" blocks that
    surface as standard "reasoning" blocks through ``content_blocks``.

    NOTE(review): live integration test — requires Mistral API access.
    """
    model = ChatMistralAI(model="magistral-medium-latest")  # type: ignore[call-arg]
    input_message = {
        "role": "user",
        "content": "Hello, my name is Bob.",
    }
    full: AIMessageChunk | None = None
    for chunk in model.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        # Aggregate chunks into one message via chunk addition.
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    thinking_blocks = 0
    for i, block in enumerate(full.content):
        if isinstance(block, dict) and block.get("type") == "thinking":
            thinking_blocks += 1
            # Assumes content_blocks is index-aligned with content (one
            # reasoning block per thinking block) — TODO confirm for
            # multi-sub-block thinking payloads.
            reasoning_block = full.content_blocks[i]
            assert reasoning_block["type"] == "reasoning"
            assert isinstance(reasoning_block.get("reasoning"), str)
    assert thinking_blocks > 0
    # Round-trip: the aggregated message must serialize back into a request.
    next_message = {"role": "user", "content": "What is my name?"}
    _ = model.invoke([input_message, full, next_message])
def test_reasoning_v1() -> None:
    """With ``output_version="v1"`` the streamed content itself should hold
    standard "reasoning" blocks (no translation step needed).

    NOTE(review): live integration test — requires Mistral API access.
    """
    model = ChatMistralAI(model="magistral-medium-latest", output_version="v1")  # type: ignore[call-arg]
    input_message = {
        "role": "user",
        "content": "Hello, my name is Bob.",
    }
    full: AIMessageChunk | None = None
    chunks = []
    for chunk in model.stream([input_message]):
        assert isinstance(chunk, AIMessageChunk)
        full = chunk if full is None else full + chunk
        chunks.append(chunk)
    assert isinstance(full, AIMessageChunk)
    reasoning_blocks = 0
    for block in full.content:
        if isinstance(block, dict) and block.get("type") == "reasoning":
            reasoning_blocks += 1
            assert isinstance(block.get("reasoning"), str)
    assert reasoning_blocks > 0
    # Round-trip: v1 content must convert back to Mistral's request format.
    next_message = {"role": "user", "content": "What is my name?"}
    _ = model.invoke([input_message, full, next_message])

View File

@@ -188,6 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None:
type="tool_call",
)
],
response_metadata={"model_provider": "mistralai"},
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message
@@ -231,6 +232,7 @@ def test__convert_dict_to_message_tool_call() -> None:
type="tool_call",
),
],
response_metadata={"model_provider": "mistralai"},
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message