mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
feat(mistralai): support reasoning feature and v1 content (#33485)
Not yet supported: server-side tool calls
This commit is contained in:
125
libs/partners/mistralai/langchain_mistralai/_compat.py
Normal file
125
libs/partners/mistralai/langchain_mistralai/_compat.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Derivations of standard content blocks from mistral content."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
|
||||
|
||||
def _convert_from_v1_to_mistral(
|
||||
content: list[types.ContentBlock],
|
||||
model_provider: str | None,
|
||||
) -> str | list[str | dict]:
|
||||
new_content: list = []
|
||||
for block in content:
|
||||
if block["type"] == "text":
|
||||
new_content.append({"text": block.get("text", ""), "type": "text"})
|
||||
|
||||
elif (
|
||||
block["type"] == "reasoning"
|
||||
and (reasoning := block.get("reasoning"))
|
||||
and isinstance(reasoning, str)
|
||||
and model_provider == "mistralai"
|
||||
):
|
||||
new_content.append(
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": [{"type": "text", "text": reasoning}],
|
||||
}
|
||||
)
|
||||
|
||||
elif (
|
||||
block["type"] == "non_standard"
|
||||
and "value" in block
|
||||
and model_provider == "mistralai"
|
||||
):
|
||||
new_content.append(block["value"])
|
||||
elif block["type"] == "tool_call":
|
||||
continue
|
||||
else:
|
||||
new_content.append(block)
|
||||
|
||||
return new_content
|
||||
|
||||
|
||||
def _convert_to_v1_from_mistral(message: AIMessage) -> list[types.ContentBlock]:
|
||||
"""Convert mistral message content to v1 format."""
|
||||
if isinstance(message.content, str):
|
||||
content_blocks: list[types.ContentBlock] = [
|
||||
{"type": "text", "text": message.content}
|
||||
]
|
||||
|
||||
else:
|
||||
content_blocks = []
|
||||
for block in message.content:
|
||||
if isinstance(block, str):
|
||||
content_blocks.append({"type": "text", "text": block})
|
||||
|
||||
elif isinstance(block, dict):
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text_block: types.TextContentBlock = {
|
||||
"type": "text",
|
||||
"text": block["text"],
|
||||
}
|
||||
if "index" in block:
|
||||
text_block["index"] = block["index"]
|
||||
content_blocks.append(text_block)
|
||||
|
||||
elif block.get("type") == "thinking" and isinstance(
|
||||
block.get("thinking"), list
|
||||
):
|
||||
for sub_block in block["thinking"]:
|
||||
if (
|
||||
isinstance(sub_block, dict)
|
||||
and sub_block.get("type") == "text"
|
||||
):
|
||||
reasoning_block: types.ReasoningContentBlock = {
|
||||
"type": "reasoning",
|
||||
"reasoning": sub_block.get("text", ""),
|
||||
}
|
||||
if "index" in block:
|
||||
reasoning_block["index"] = block["index"]
|
||||
content_blocks.append(reasoning_block)
|
||||
|
||||
else:
|
||||
non_standard_block: types.NonStandardContentBlock = {
|
||||
"type": "non_standard",
|
||||
"value": block,
|
||||
}
|
||||
content_blocks.append(non_standard_block)
|
||||
else:
|
||||
continue
|
||||
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0].get("type") == "text"
|
||||
and content_blocks[0].get("text") == ""
|
||||
and message.tool_calls
|
||||
):
|
||||
content_blocks = []
|
||||
|
||||
for tool_call in message.tool_calls:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
"id": tool_call.get("id"),
|
||||
}
|
||||
)
|
||||
|
||||
return content_blocks
|
||||
|
||||
|
||||
def translate_content(message: AIMessage) -> list[types.ContentBlock]:
    """Derive standard content blocks from a message with mistral content.

    Args:
        message: AI message whose ``content`` uses mistral-native blocks.

    Returns:
        Equivalent list of v1 standard content blocks.
    """
    return _convert_to_v1_from_mistral(message)
|
||||
|
||||
|
||||
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
    """Derive standard content blocks from a message chunk with mistral content.

    Args:
        message: Streaming AI message chunk with mistral-native content.

    Returns:
        Equivalent list of v1 standard content blocks.
    """
    return _convert_to_v1_from_mistral(message)
|
||||
|
||||
|
||||
# Register the converters so langchain-core can derive standard (v1) content
# blocks for messages tagged with model_provider "mistralai".
register_translator("mistralai", translate_content, translate_content_chunk)
|
||||
@@ -24,12 +24,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
LangSmithParams,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -74,6 +69,8 @@ from pydantic import (
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_mistralai._compat import _convert_from_v1_to_mistral
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
@@ -160,6 +157,7 @@ def _convert_mistral_chat_message_to_message(
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
response_metadata={"model_provider": "mistralai"},
|
||||
)
|
||||
|
||||
|
||||
@@ -231,14 +229,34 @@ async def acompletion_with_retry(
|
||||
|
||||
|
||||
def _convert_chunk_to_message_chunk(
|
||||
chunk: dict, default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
chunk: dict,
|
||||
default_class: type[BaseMessageChunk],
|
||||
index: int,
|
||||
index_type: str,
|
||||
output_version: str | None,
|
||||
) -> tuple[BaseMessageChunk, int, str]:
|
||||
_choice = chunk["choices"][0]
|
||||
_delta = _choice["delta"]
|
||||
role = _delta.get("role")
|
||||
content = _delta.get("content") or ""
|
||||
if output_version == "v1" and isinstance(content, str):
|
||||
content = [{"type": "text", "text": content}]
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
if "type" in block and block["type"] != index_type:
|
||||
index_type = block["type"]
|
||||
index = index + 1
|
||||
if "index" not in block:
|
||||
block["index"] = index
|
||||
if block.get("type") == "thinking" and isinstance(
|
||||
block.get("thinking"), list
|
||||
):
|
||||
for sub_block in block["thinking"]:
|
||||
if isinstance(sub_block, dict) and "index" not in sub_block:
|
||||
sub_block["index"] = 0
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
return HumanMessageChunk(content=content), index, index_type
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: dict = {}
|
||||
response_metadata = {}
|
||||
@@ -276,18 +294,22 @@ def _convert_chunk_to_message_chunk(
|
||||
):
|
||||
response_metadata["model_name"] = chunk["model"]
|
||||
response_metadata["finish_reason"] = _choice["finish_reason"]
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata=response_metadata,
|
||||
return (
|
||||
AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
response_metadata={"model_provider": "mistralai", **response_metadata},
|
||||
),
|
||||
index,
|
||||
index_type,
|
||||
)
|
||||
if role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
return SystemMessageChunk(content=content), index, index_type
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
return default_class(content=content) # type: ignore[call-arg]
|
||||
return ChatMessageChunk(content=content, role=role), index, index_type
|
||||
return default_class(content=content), index, index_type # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
||||
@@ -318,6 +340,21 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
|
||||
return result
|
||||
|
||||
|
||||
def _clean_block(block: dict) -> dict:
|
||||
# Remove "index" key added for message aggregation in langchain-core
|
||||
new_block = {k: v for k, v in block.items() if k != "index"}
|
||||
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
|
||||
new_block["thinking"] = [
|
||||
(
|
||||
{k: v for k, v in sb.items() if k != "index"}
|
||||
if isinstance(sb, dict) and "index" in sb
|
||||
else sb
|
||||
)
|
||||
for sb in block["thinking"]
|
||||
]
|
||||
return new_block
|
||||
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> dict:
|
||||
@@ -356,13 +393,40 @@ def _convert_message_to_mistral_chat_message(
|
||||
pass
|
||||
if tool_calls: # do not populate empty list tool_calls
|
||||
message_dict["tool_calls"] = tool_calls
|
||||
if tool_calls and message.content:
|
||||
|
||||
# Message content
|
||||
# Translate v1 content
|
||||
if message.response_metadata.get("output_version") == "v1":
|
||||
content = _convert_from_v1_to_mistral(
|
||||
message.content_blocks, message.response_metadata.get("model_provider")
|
||||
)
|
||||
else:
|
||||
content = message.content
|
||||
|
||||
if tool_calls and content:
|
||||
# Assistant message must have either content or tool_calls, but not both.
|
||||
# Some providers may not support tool_calls in the same message as content.
|
||||
# This is done to ensure compatibility with messages from other providers.
|
||||
message_dict["content"] = ""
|
||||
content = ""
|
||||
|
||||
elif isinstance(content, list):
|
||||
content = [
|
||||
_clean_block(block)
|
||||
if isinstance(block, dict) and "index" in block
|
||||
else block
|
||||
for block in content
|
||||
]
|
||||
else:
|
||||
message_dict["content"] = message.content
|
||||
content = message.content
|
||||
|
||||
# if any blocks are dicts, cast strings to text blocks
|
||||
if any(isinstance(block, dict) for block in content):
|
||||
content = [
|
||||
block if isinstance(block, dict) else {"type": "text", "text": block}
|
||||
for block in content
|
||||
]
|
||||
message_dict["content"] = content
|
||||
|
||||
if "prefix" in message.additional_kwargs:
|
||||
message_dict["prefix"] = message.additional_kwargs["prefix"]
|
||||
return message_dict
|
||||
@@ -564,13 +628,6 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = self.completion_with_retry(
|
||||
@@ -627,12 +684,16 @@ class ChatMistralAI(BaseChatModel):
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
index = -1
|
||||
index_type = ""
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.get("choices", [])) == 0:
|
||||
continue
|
||||
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
|
||||
chunk, default_chunk_class, index, index_type, self.output_version
|
||||
)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
@@ -653,12 +714,16 @@ class ChatMistralAI(BaseChatModel):
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
|
||||
index = -1
|
||||
index_type = ""
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.get("choices", [])) == 0:
|
||||
continue
|
||||
new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
|
||||
new_chunk, index, index_type = _convert_chunk_to_message_chunk(
|
||||
chunk, default_chunk_class, index, index_type, self.output_version
|
||||
)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
@@ -676,13 +741,6 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs}
|
||||
response = await acompletion_with_retry(
|
||||
|
||||
@@ -28,7 +28,9 @@ async def test_astream() -> None:
|
||||
full = token if full is None else full + token
|
||||
if token.usage_metadata is not None:
|
||||
chunks_with_token_counts += 1
|
||||
if token.response_metadata:
|
||||
if token.response_metadata and not set(token.response_metadata.keys()).issubset(
|
||||
{"model_provider", "output_version"}
|
||||
):
|
||||
chunks_with_response_metadata += 1
|
||||
if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
|
||||
msg = (
|
||||
@@ -143,3 +145,51 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
|
||||
|
||||
def test_reasoning() -> None:
|
||||
model = ChatMistralAI(model="magistral-medium-latest") # type: ignore[call-arg]
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": "Hello, my name is Bob.",
|
||||
}
|
||||
full: AIMessageChunk | None = None
|
||||
for chunk in model.stream([input_message]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
thinking_blocks = 0
|
||||
for i, block in enumerate(full.content):
|
||||
if isinstance(block, dict) and block.get("type") == "thinking":
|
||||
thinking_blocks += 1
|
||||
reasoning_block = full.content_blocks[i]
|
||||
assert reasoning_block["type"] == "reasoning"
|
||||
assert isinstance(reasoning_block.get("reasoning"), str)
|
||||
assert thinking_blocks > 0
|
||||
|
||||
next_message = {"role": "user", "content": "What is my name?"}
|
||||
_ = model.invoke([input_message, full, next_message])
|
||||
|
||||
|
||||
def test_reasoning_v1() -> None:
|
||||
model = ChatMistralAI(model="magistral-medium-latest", output_version="v1") # type: ignore[call-arg]
|
||||
input_message = {
|
||||
"role": "user",
|
||||
"content": "Hello, my name is Bob.",
|
||||
}
|
||||
full: AIMessageChunk | None = None
|
||||
chunks = []
|
||||
for chunk in model.stream([input_message]):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
full = chunk if full is None else full + chunk
|
||||
chunks.append(chunk)
|
||||
assert isinstance(full, AIMessageChunk)
|
||||
reasoning_blocks = 0
|
||||
for block in full.content:
|
||||
if isinstance(block, dict) and block.get("type") == "reasoning":
|
||||
reasoning_blocks += 1
|
||||
assert isinstance(block.get("reasoning"), str)
|
||||
assert reasoning_blocks > 0
|
||||
|
||||
next_message = {"role": "user", "content": "What is my name?"}
|
||||
_ = model.invoke([input_message, full, next_message])
|
||||
|
||||
@@ -188,6 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
response_metadata={"model_provider": "mistralai"},
|
||||
)
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_mistral_chat_message(expected_output) == message
|
||||
@@ -231,6 +232,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
response_metadata={"model_provider": "mistralai"},
|
||||
)
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_mistral_chat_message(expected_output) == message
|
||||
|
||||
Reference in New Issue
Block a user