feat(mistralai): support reasoning feature and v1 content (#33485)

Not yet supported: server-side tool calls
This commit is contained in:
ccurme
2025-10-14 15:19:44 -04:00
committed by GitHub
parent 99e0a60aab
commit 9f4366bc9d
4 changed files with 273 additions and 38 deletions

View File

@@ -24,12 +24,7 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
@@ -74,6 +69,8 @@ from pydantic import (
)
from typing_extensions import Self
from langchain_mistralai._compat import _convert_from_v1_to_mistral
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
@@ -160,6 +157,7 @@ def _convert_mistral_chat_message_to_message(
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata={"model_provider": "mistralai"},
)
@@ -231,14 +229,34 @@ async def acompletion_with_retry(
def _convert_chunk_to_message_chunk(
chunk: dict, default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
chunk: dict,
default_class: type[BaseMessageChunk],
index: int,
index_type: str,
output_version: str | None,
) -> tuple[BaseMessageChunk, int, str]:
_choice = chunk["choices"][0]
_delta = _choice["delta"]
role = _delta.get("role")
content = _delta.get("content") or ""
if output_version == "v1" and isinstance(content, str):
content = [{"type": "text", "text": content}]
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
if "type" in block and block["type"] != index_type:
index_type = block["type"]
index = index + 1
if "index" not in block:
block["index"] = index
if block.get("type") == "thinking" and isinstance(
block.get("thinking"), list
):
for sub_block in block["thinking"]:
if isinstance(sub_block, dict) and "index" not in sub_block:
sub_block["index"] = 0
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
return HumanMessageChunk(content=content), index, index_type
if role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: dict = {}
response_metadata = {}
@@ -276,18 +294,22 @@ def _convert_chunk_to_message_chunk(
):
response_metadata["model_name"] = chunk["model"]
response_metadata["finish_reason"] = _choice["finish_reason"]
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata=response_metadata,
return (
AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
usage_metadata=usage_metadata, # type: ignore[arg-type]
response_metadata={"model_provider": "mistralai", **response_metadata},
),
index,
index_type,
)
if role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
return SystemMessageChunk(content=content), index, index_type
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
return default_class(content=content) # type: ignore[call-arg]
return ChatMessageChunk(content=content, role=role), index, index_type
return default_class(content=content), index, index_type # type: ignore[call-arg]
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
@@ -318,6 +340,21 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
return result
def _clean_block(block: dict) -> dict:
# Remove "index" key added for message aggregation in langchain-core
new_block = {k: v for k, v in block.items() if k != "index"}
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
new_block["thinking"] = [
(
{k: v for k, v in sb.items() if k != "index"}
if isinstance(sb, dict) and "index" in sb
else sb
)
for sb in block["thinking"]
]
return new_block
def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> dict:
@@ -356,13 +393,40 @@ def _convert_message_to_mistral_chat_message(
pass
if tool_calls: # do not populate empty list tool_calls
message_dict["tool_calls"] = tool_calls
if tool_calls and message.content:
# Message content
# Translate v1 content
if message.response_metadata.get("output_version") == "v1":
content = _convert_from_v1_to_mistral(
message.content_blocks, message.response_metadata.get("model_provider")
)
else:
content = message.content
if tool_calls and content:
# Assistant message must have either content or tool_calls, but not both.
# Some providers may not support tool_calls in the same message as content.
# This is done to ensure compatibility with messages from other providers.
message_dict["content"] = ""
content = ""
elif isinstance(content, list):
content = [
_clean_block(block)
if isinstance(block, dict) and "index" in block
else block
for block in content
]
else:
message_dict["content"] = message.content
content = message.content
# if any blocks are dicts, cast strings to text blocks
if any(isinstance(block, dict) for block in content):
content = [
block if isinstance(block, dict) else {"type": "text", "text": block}
for block in content
]
message_dict["content"] = content
if "prefix" in message.additional_kwargs:
message_dict["prefix"] = message.additional_kwargs["prefix"]
return message_dict
@@ -564,13 +628,6 @@ class ChatMistralAI(BaseChatModel):
stream: bool | None = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
@@ -627,12 +684,16 @@ class ChatMistralAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
index = -1
index_type = ""
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.get("choices", [])) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
chunk, default_chunk_class, index, index_type, self.output_version
)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -653,12 +714,16 @@ class ChatMistralAI(BaseChatModel):
params = {**params, **kwargs, "stream": True}
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
index = -1
index_type = ""
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if len(chunk.get("choices", [])) == 0:
continue
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
chunk, default_chunk_class, index, index_type, self.output_version
)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -676,13 +741,6 @@ class ChatMistralAI(BaseChatModel):
stream: bool | None = None, # noqa: FBT001
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(