Files
langchain/libs/partners/mistralai/langchain_mistralai/chat_models.py

1223 lines
46 KiB
Python

from __future__ import annotations
import hashlib
import json
import logging
import os
import re
import ssl
import uuid
from collections.abc import Callable # noqa: TC003
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
Literal,
cast,
)
import certifi
import httpx
from httpx_sse import EventSource, aconnect_sse, connect_sse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import (
LanguageModelInput,
ModelProfile,
ModelProfileRegistry,
)
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import get_pydantic_field_names, secret_from_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import _build_model_kwargs
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self
from langchain_mistralai._compat import _convert_from_v1_to_mistral
from langchain_mistralai.data._profiles import _PROFILES
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import AbstractAsyncContextManager
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Mistral enforces a specific pattern for tool call IDs
TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")
# This SSL context is equivalent to the default `verify=True`.
# https://www.python-httpx.org/advanced/ssl/#configuring-client-instances
global_ssl_context = ssl.create_default_context(cafile=certifi.where())
# Bundled model profile data. `cast` only informs static type checkers; the
# object is used as-is at runtime.
_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)
def _get_default_model_profile(model_name: str) -> ModelProfile:
    """Return a shallow copy of the bundled profile for `model_name`.

    When no profile is registered for the model (or the registered profile is
    empty) a fresh empty profile is returned instead.
    """
    registered = _MODEL_PROFILES.get(model_name)
    profile = registered if registered else {}
    return profile.copy()
def _create_retry_decorator(
    llm: ChatMistralAI,
    run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used around Mistral API calls.

    `httpx.RequestError` and `httpx.StreamError` are registered as the retryable
    error types, and `llm.max_retries` is forwarded as the retry limit.
    """
    retryable_errors = [httpx.RequestError, httpx.StreamError]
    decorator = create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )
    return decorator
def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
    """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9.

    `fullmatch` is used rather than `match` because `$` also matches just
    before a trailing newline, which would otherwise accept a ten-character
    value such as nine alphanumerics followed by a newline.
    """
    return TOOL_CALL_ID_PATTERN.fullmatch(tool_call_id) is not None
def _base62_encode(num: int) -> str:
"""Encode a number in base62 and ensures result is of a specified length."""
base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
if num == 0:
return base62[0]
arr = []
base = len(base62)
while num:
num, rem = divmod(num, base)
arr.append(base62[rem])
arr.reverse()
return "".join(arr)
def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
    """Map any tool call ID onto a nine-character ID for Mistral.

    IDs that already pass `_is_valid_mistral_tool_call_id` are returned
    unchanged. Otherwise the ID is derived deterministically from the SHA-256
    digest of the original, encoded in base62.
    """
    if _is_valid_mistral_tool_call_id(tool_call_id):
        return tool_call_id
    digest = hashlib.sha256(tool_call_id.encode()).digest()
    encoded = _base62_encode(int.from_bytes(digest, byteorder="big"))
    # Keep the first nine characters; left-pad with "0" if the encoding is shorter.
    return encoded[:9].rjust(9, "0")
def _convert_mistral_chat_message_to_message(
_message: dict,
) -> BaseMessage:
role = _message["role"]
if role != "assistant":
msg = f"Expected role to be 'assistant', got {role}"
raise ValueError(msg)
# Mistral returns None for tool invocations
content = _message.get("content", "") or ""
additional_kwargs: dict = {}
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _message.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
parsed: dict = cast(
"dict", parse_tool_call(raw_tool_call, return_id=True)
)
if not parsed["id"]:
parsed["id"] = uuid.uuid4().hex[:]
tool_calls.append(parsed)
except Exception as e:
invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata={"model_provider": "mistralai"},
)
def _raise_on_error(response: httpx.Response) -> None:
    """Raise `httpx.HTTPStatusError` when the response has an error status.

    The response body is read and embedded in the exception message.
    """
    if not httpx.codes.is_error(response.status_code):
        return
    body_text = response.read().decode("utf-8")
    detail = (
        f"Error response {response.status_code} "
        f"while fetching {response.url}: {body_text}"
    )
    raise httpx.HTTPStatusError(detail, request=response.request, response=response)
async def _araise_on_error(response: httpx.Response) -> None:
    """Async variant: raise `httpx.HTTPStatusError` on an error status.

    The response body is awaited and embedded in the exception message.
    """
    if not httpx.codes.is_error(response.status_code):
        return
    raw_body = await response.aread()
    body_text = raw_body.decode("utf-8")
    detail = (
        f"Error response {response.status_code} "
        f"while fetching {response.url}: {body_text}"
    )
    raise httpx.HTTPStatusError(detail, request=response.request, response=response)
async def _aiter_sse(
    event_source_mgr: AbstractAsyncContextManager[EventSource],
) -> AsyncIterator[dict]:
    """Yield the decoded JSON payload of each server-sent event.

    The event source is entered here, its response is checked for an error
    status, and iteration ends at the `[DONE]` sentinel.
    """
    async with event_source_mgr as source:
        await _araise_on_error(source.response)
        async for sse in source.aiter_sse():
            if sse.data == "[DONE]":
                break
            yield sse.json()
async def acompletion_with_retry(
    llm: ChatMistralAI,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Any:
    """Call the chat completions endpoint asynchronously, with retries.

    Returns the decoded JSON response, or, when `stream` is truthy, an async
    iterator over the decoded server-sent events.

    NOTE(review): in streaming mode the event source is only entered inside
    `_aiter_sse`, i.e. once the returned iterator is consumed, which happens
    after the retry-wrapped call has already returned.
    """
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        kwargs.setdefault("stream", False)
        if not kwargs["stream"]:
            response = await llm.async_client.post(
                url="/chat/completions", json=kwargs
            )
            await _araise_on_error(response)
            return response.json()
        event_source = aconnect_sse(
            llm.async_client, "POST", "/chat/completions", json=kwargs
        )
        return _aiter_sse(event_source)

    return await _completion_with_retry(**kwargs)
def _convert_chunk_to_message_chunk(
    chunk: dict,
    default_class: type[BaseMessageChunk],
    index: int,
    index_type: str,
    output_version: str | None,
) -> tuple[BaseMessageChunk, int, str]:
    """Convert one streamed completion chunk into a LangChain message chunk.

    Only the first entry of `chunk["choices"]` is read.

    Args:
        chunk: Decoded streaming payload; must contain a non-empty `"choices"`
            list whose first item has a `"delta"` mapping.
        default_class: Message chunk class used when the delta's `"role"` does
            not decide the type.
        index: Running content-block index carried across chunks of a stream.
        index_type: The `"type"` of the most recently indexed content block.
        output_version: When `"v1"`, string content is wrapped in a single text
            block before indexing.

    Returns:
        A tuple of the message chunk and the updated `index` and `index_type`,
        to be passed back in for the next chunk.
    """
    _choice = chunk["choices"][0]
    _delta = _choice["delta"]
    role = _delta.get("role")
    content = _delta.get("content") or ""
    if output_version == "v1" and isinstance(content, str):
        content = [{"type": "text", "text": content}]
    if isinstance(content, list):
        # Stamp dict blocks (in place) with an index; the running index advances
        # whenever a block's type differs from the previously seen type.
        for block in content:
            if isinstance(block, dict):
                if "type" in block and block["type"] != index_type:
                    index_type = block["type"]
                    index = index + 1
                if "index" not in block:
                    block["index"] = index
                if block.get("type") == "thinking" and isinstance(
                    block.get("thinking"), list
                ):
                    # Sub-blocks of a thinking block without an index get index 0.
                    for sub_block in block["thinking"]:
                        if isinstance(sub_block, dict) and "index" not in sub_block:
                            sub_block["index"] = 0
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content), index, index_type
    if role == "assistant" or default_class == AIMessageChunk:
        additional_kwargs: dict = {}
        response_metadata = {}
        if raw_tool_calls := _delta.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            try:
                tool_call_chunks = []
                for raw_tool_call in raw_tool_calls:
                    # An ID is generated when the call has neither a truthy index
                    # (i.e. index missing or 0) nor an ID of its own.
                    if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
                        tool_call_id = uuid.uuid4().hex[:]
                    else:
                        tool_call_id = raw_tool_call.get("id")
                    tool_call_chunks.append(
                        tool_call_chunk(
                            name=raw_tool_call["function"].get("name"),
                            args=raw_tool_call["function"].get("arguments"),
                            id=tool_call_id,
                            index=raw_tool_call.get("index"),
                        )
                    )
            except KeyError:
                # A tool call without a "function" key stops the loop; the chunks
                # collected before it are kept.
                pass
        else:
            tool_call_chunks = []
        if token_usage := chunk.get("usage"):
            usage_metadata = {
                "input_tokens": token_usage.get("prompt_tokens", 0),
                "output_tokens": token_usage.get("completion_tokens", 0),
                "total_tokens": token_usage.get("total_tokens", 0),
            }
        else:
            usage_metadata = None
        # Model name and finish reason are only recorded on a chunk that carries a
        # finish reason together with a string model name.
        if _choice.get("finish_reason") is not None and isinstance(
            chunk.get("model"), str
        ):
            response_metadata["model_name"] = chunk["model"]
            response_metadata["finish_reason"] = _choice["finish_reason"]
        return (
            AIMessageChunk(
                content=content,
                additional_kwargs=additional_kwargs,
                tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
                usage_metadata=usage_metadata,  # type: ignore[arg-type]
                response_metadata={"model_provider": "mistralai", **response_metadata},
            ),
            index,
            index_type,
        )
    if role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content), index, index_type
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role), index, index_type
    return default_class(content=content), index, index_type  # type: ignore[call-arg]
def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
"""Format LangChain ToolCall to dict expected by Mistral."""
result: dict[str, Any] = {
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
}
}
if _id := tool_call.get("id"):
result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)
return result
def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
"""Format LangChain InvalidToolCall to dict expected by Mistral."""
result: dict[str, Any] = {
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
}
}
if _id := invalid_tool_call.get("id"):
result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)
return result
def _clean_block(block: dict) -> dict:
# Remove "index" key added for message aggregation in langchain-core
new_block = {k: v for k, v in block.items() if k != "index"}
if block.get("type") == "thinking" and isinstance(block.get("thinking"), list):
new_block["thinking"] = [
(
{k: v for k, v in sb.items() if k != "index"}
if isinstance(sb, dict) and "index" in sb
else sb
)
for sb in block["thinking"]
]
return new_block
def _convert_message_to_mistral_chat_message(
    message: BaseMessage,
) -> dict:
    """Convert a LangChain message into the dict sent to Mistral.

    `ChatMessage`, `HumanMessage`, `AIMessage`, `SystemMessage` and
    `ToolMessage` are handled, checked in that order.

    Raises:
        ValueError: If `message` is none of the handled types.
    """
    if isinstance(message, ChatMessage):
        return {"role": message.role, "content": message.content}
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if isinstance(message, AIMessage):
        message_dict: dict[str, Any] = {"role": "assistant"}
        tool_calls: list = []
        # Parsed tool calls take precedence; raw ones in `additional_kwargs` are
        # only used when the message carries no parsed or invalid tool calls.
        if message.tool_calls or message.invalid_tool_calls:
            if message.tool_calls:
                tool_calls.extend(
                    _format_tool_call_for_mistral(tool_call)
                    for tool_call in message.tool_calls
                )
            if message.invalid_tool_calls:
                tool_calls.extend(
                    _format_invalid_tool_call_for_mistral(invalid_tool_call)
                    for invalid_tool_call in message.invalid_tool_calls
                )
        elif "tool_calls" in message.additional_kwargs:
            for tc in message.additional_kwargs["tool_calls"]:
                chunk = {
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    }
                }
                # NOTE(review): unlike the formatted paths above, this ID is
                # forwarded without conversion to a Mistral-compatible ID.
                if _id := tc.get("id"):
                    chunk["id"] = _id
                tool_calls.append(chunk)
        else:
            pass
        if tool_calls:  # do not populate empty list tool_calls
            message_dict["tool_calls"] = tool_calls
        # Message content
        # Translate v1 content
        if message.response_metadata.get("output_version") == "v1":
            content = _convert_from_v1_to_mistral(
                message.content_blocks, message.response_metadata.get("model_provider")
            )
        else:
            content = message.content
        if tool_calls and content:
            # Assistant message must have either content or tool_calls, but not both.
            # Some providers may not support tool_calls in the same message as content.
            # This is done to ensure compatibility with messages from other providers.
            content = ""
        elif isinstance(content, list):
            # Drop the aggregation-only "index" keys before sending.
            content = [
                _clean_block(block)
                if isinstance(block, dict) and "index" in block
                else block
                for block in content
            ]
        else:
            content = message.content
        # if any blocks are dicts, cast strings to text blocks
        if any(isinstance(block, dict) for block in content):
            content = [
                block if isinstance(block, dict) else {"type": "text", "text": block}
                for block in content
            ]
        message_dict["content"] = content
        if "prefix" in message.additional_kwargs:
            message_dict["prefix"] = message.additional_kwargs["prefix"]
        return message_dict
    if isinstance(message, SystemMessage):
        return {"role": "system", "content": message.content}
    if isinstance(message, ToolMessage):
        return {
            "role": "tool",
            "content": message.content,
            "name": message.name,
            "tool_call_id": _convert_tool_call_id_to_mistral_compatible(
                message.tool_call_id
            ),
        }
    msg = f"Got unknown type {message}"
    raise ValueError(msg)
class ChatMistralAI(BaseChatModel):
    """A chat model that uses the Mistral AI API."""

    # The type for client and async_client is ignored because the type is not
    # an Optional after the model is initialized and the model_validator
    # is run.
    client: httpx.Client = Field(  # type: ignore[assignment] # : meta private:
        default=None, exclude=True
    )
    async_client: httpx.AsyncClient = Field(  # type: ignore[assignment] # : meta private:
        default=None, exclude=True
    )
    # API key (alias `api_key`); defaults to the `MISTRAL_API_KEY` environment
    # variable, or `None` when that is unset.
    mistral_api_key: SecretStr | None = Field(
        alias="api_key",
        default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
    )
    # Base URL of the API (alias `base_url`). When unset, `validate_environment`
    # falls back to `MISTRAL_BASE_URL` and then to "https://api.mistral.ai/v1".
    endpoint: str | None = Field(default=None, alias="base_url")
    # Retry limit handed to the retry decorator.
    max_retries: int = 5
    # Timeout passed to both httpx clients.
    timeout: int = 120
    # NOTE(review): not referenced in the reviewed portion of this module -
    # confirm where (or whether) this limit is applied.
    max_concurrent_requests: int = 64
    # Model name sent as the `model` request parameter (alias `model_name`).
    model: str = Field(default="mistral-small", alias="model_name")
    # Must be within [0.0, 1.0]; checked in `validate_environment`.
    temperature: float = 0.7
    # Sent as `max_tokens` only when not None.
    max_tokens: int | None = None
    top_p: float = 1
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least `top_p`. Must be in the closed interval
    `[0.0, 1.0]`."""
    # Sent as `random_seed` only when not None.
    random_seed: int | None = None
    # Sent to the API under the name `safe_prompt` when not None.
    safe_mode: bool | None = None
    # NOTE(review): not referenced in the reviewed portion of this module.
    streaming: bool = False
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Holds any invocation parameters not explicitly specified."""
    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
    )
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict[str, Any]) -> Any:
    """Build extra kwargs from any additional parameters that were passed in.

    Delegates to `_build_model_kwargs`, giving it this class's pydantic field
    names.
    """
    known_field_names = get_pydantic_field_names(cls)
    processed_values = _build_model_kwargs(values, known_field_names)
    return processed_values
@property
def _default_params(self) -> dict[str, Any]:
"""Get the default parameters for calling the API."""
defaults = {
"model": self.model,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"random_seed": self.random_seed,
"safe_prompt": self.safe_mode,
**self.model_kwargs,
}
return {k: v for k, v in defaults.items() if v is not None}
def _get_ls_params(
    self, stop: list[str] | None = None, **kwargs: Any
) -> LangSmithParams:
    """Assemble the standard tracing parameters for this invocation.

    Values come from the invocation params, falling back to the instance's
    own `model`, `temperature` and `max_tokens`. `ls_max_tokens` and `ls_stop`
    are only set when truthy.
    """
    invocation = self._get_invocation_params(stop=stop, **kwargs)
    ls_params = LangSmithParams(
        ls_provider="mistral",
        ls_model_name=invocation.get("model", self.model),
        ls_model_type="chat",
        ls_temperature=invocation.get("temperature", self.temperature),
    )
    max_tokens = invocation.get("max_tokens", self.max_tokens)
    if max_tokens:
        ls_params["ls_max_tokens"] = max_tokens
    stop_sequences = stop or invocation.get("stop", None)
    if stop_sequences:
        ls_params["ls_stop"] = stop_sequences
    return ls_params
@property
def _client_params(self) -> dict[str, Any]:
"""Get the parameters used for the client."""
return self._default_params
def completion_with_retry(
    self, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any
) -> Any:
    """Call the chat completions endpoint synchronously, with retries.

    Returns the decoded JSON response, or, when `stream` is truthy, an
    iterator over the decoded server-sent events (ending at `[DONE]`).

    NOTE(review): in streaming mode the request is issued inside a generator,
    i.e. once the returned iterator is consumed, which happens after the
    retry-wrapped call has already returned.
    """
    retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

    def _stream_events(payload: dict[str, Any]) -> Iterator[dict]:
        with connect_sse(
            self.client, "POST", "/chat/completions", json=payload
        ) as event_source:
            _raise_on_error(event_source.response)
            for event in event_source.iter_sse():
                if event.data == "[DONE]":
                    break
                yield event.json()

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        kwargs.setdefault("stream", False)
        if kwargs["stream"]:
            return _stream_events(kwargs)
        response = self.client.post(url="/chat/completions", json=kwargs)
        _raise_on_error(response)
        return response.json()

    return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
if token_usage is not None:
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}
@model_validator(mode="after")
def validate_environment(self) -> Self:
    """Resolve the endpoint, create missing HTTP clients, and check ranges.

    The base URL is taken from `endpoint`, then the `MISTRAL_BASE_URL`
    environment variable, then the public Mistral API URL, and is stored back
    on `endpoint`. Clients already supplied by the caller are kept.

    Raises:
        ValueError: If `temperature` or `top_p` lies outside `[0.0, 1.0]`.
    """
    secret = self.mistral_api_key
    api_key_str: str | None = (
        secret.get_secret_value() if isinstance(secret, SecretStr) else secret
    )

    # TODO: handle retries
    base_url_str = (
        self.endpoint
        or os.environ.get("MISTRAL_BASE_URL")
        or "https://api.mistral.ai/v1"
    )
    self.endpoint = base_url_str

    def _client_options() -> dict[str, Any]:
        # Built fresh per client so the two clients never share a headers dict.
        return {
            "base_url": base_url_str,
            "headers": {
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key_str}",
            },
            "timeout": self.timeout,
            "verify": global_ssl_context,
        }

    if not self.client:
        self.client = httpx.Client(**_client_options())
    # TODO: handle retries and max_concurrency
    if not self.async_client:
        self.async_client = httpx.AsyncClient(**_client_options())

    temperature = self.temperature
    if temperature is not None and not 0 <= temperature <= 1:
        msg = "temperature must be in the range [0.0, 1.0]"
        raise ValueError(msg)
    top_p = self.top_p
    if top_p is not None and not 0 <= top_p <= 1:
        msg = "top_p must be in the range [0.0, 1.0]"
        raise ValueError(msg)
    return self
@model_validator(mode="after")
def _set_model_profile(self) -> Self:
    """Fill in the default profile for `self.model` unless one is already set."""
    if self.profile is not None:
        return self
    self.profile = _get_default_model_profile(self.model)
    return self
def _generate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    stream: bool | None = None,  # noqa: FBT001
    **kwargs: Any,
) -> ChatResult:
    """Run one non-streaming chat completion and wrap it in a `ChatResult`.

    Extra `kwargs` override the model's default request parameters. The
    `stream` argument is accepted but not used by this method.
    """
    message_dicts, default_params = self._create_message_dicts(messages, stop)
    request_params = {**default_params, **kwargs}
    response = self.completion_with_retry(
        messages=message_dicts, run_manager=run_manager, **request_params
    )
    return self._create_chat_result(response)
def _create_chat_result(self, response: dict) -> ChatResult:
    """Build a `ChatResult` from a decoded chat completions response.

    One generation is produced per entry in `response["choices"]`. When the
    response reports usage, it is attached to every `AIMessage` as
    `usage_metadata`.
    """
    token_usage = response.get("usage", {})
    generations = []
    for choice in response["choices"]:
        message = _convert_mistral_chat_message_to_message(choice["message"])
        if token_usage and isinstance(message, AIMessage):
            message.usage_metadata = {
                "input_tokens": token_usage.get("prompt_tokens", 0),
                "output_tokens": token_usage.get("completion_tokens", 0),
                "total_tokens": token_usage.get("total_tokens", 0),
            }
        generation_info = {"finish_reason": choice.get("finish_reason")}
        generations.append(
            ChatGeneration(message=message, generation_info=generation_info)
        )
    return ChatResult(
        generations=generations,
        llm_output={
            "token_usage": token_usage,
            "model_name": self.model,
            "model": self.model,  # Backwards compatibility
        },
    )
def _create_message_dicts(
    self, messages: list[BaseMessage], stop: list[str] | None
) -> tuple[list[dict], dict[str, Any]]:
    """Convert messages to Mistral dicts and collect the request parameters.

    `stop` is not supported: if it is passed, or present in the parameters, a
    warning is logged and it is removed from the parameters.
    """
    params = self._client_params
    stop_requested = stop is not None or "stop" in params
    if stop_requested:
        params.pop("stop", None)
        logger.warning(
            "Parameter `stop` not yet supported (https://docs.mistral.ai/api)"
        )
    converted = [_convert_message_to_mistral_chat_message(m) for m in messages]
    return converted, params
def _stream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Stream a chat completion as `ChatGenerationChunk` objects.

    Payloads with an empty `"choices"` list are skipped. Each yielded chunk is
    also reported to `run_manager`, when one is given.
    """
    message_dicts, base_params = self._create_message_dicts(messages, stop)
    request_params = {**base_params, **kwargs, "stream": True}
    chunk_class: type[BaseMessageChunk] = AIMessageChunk
    block_index = -1
    block_type = ""
    raw_chunks = self.completion_with_retry(
        messages=message_dicts, run_manager=run_manager, **request_params
    )
    for raw_chunk in raw_chunks:
        if len(raw_chunk.get("choices", [])) == 0:
            continue
        message_chunk, block_index, block_type = _convert_chunk_to_message_chunk(
            raw_chunk, chunk_class, block_index, block_type, self.output_version
        )
        # make future chunks same type as first chunk
        chunk_class = type(message_chunk)
        generation = ChatGenerationChunk(message=message_chunk)
        if run_manager:
            run_manager.on_llm_new_token(
                token=cast("str", message_chunk.content), chunk=generation
            )
        yield generation
async def _astream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Asynchronously stream a chat completion as `ChatGenerationChunk` objects.

    Payloads with an empty `"choices"` list are skipped. Each yielded chunk is
    also reported to `run_manager`, when one is given.
    """
    message_dicts, base_params = self._create_message_dicts(messages, stop)
    request_params = {**base_params, **kwargs, "stream": True}
    chunk_class: type[BaseMessageChunk] = AIMessageChunk
    block_index = -1
    block_type = ""
    raw_chunks = await acompletion_with_retry(
        self, messages=message_dicts, run_manager=run_manager, **request_params
    )
    async for raw_chunk in raw_chunks:
        if len(raw_chunk.get("choices", [])) == 0:
            continue
        message_chunk, block_index, block_type = _convert_chunk_to_message_chunk(
            raw_chunk, chunk_class, block_index, block_type, self.output_version
        )
        # make future chunks same type as first chunk
        chunk_class = type(message_chunk)
        generation = ChatGenerationChunk(message=message_chunk)
        if run_manager:
            await run_manager.on_llm_new_token(
                token=cast("str", message_chunk.content), chunk=generation
            )
        yield generation
async def _agenerate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    stream: bool | None = None,  # noqa: FBT001
    **kwargs: Any,
) -> ChatResult:
    """Run one non-streaming chat completion asynchronously.

    Extra `kwargs` override the model's default request parameters. The
    `stream` argument is accepted but not used by this method.
    """
    message_dicts, default_params = self._create_message_dicts(messages, stop)
    request_params = {**default_params, **kwargs}
    response = await acompletion_with_retry(
        self, messages=message_dicts, run_manager=run_manager, **request_params
    )
    return self._create_chat_result(response)
def bind_tools(
    self,
    tools: Sequence[dict[str, Any] | type | Callable | BaseTool],
    tool_choice: dict | str | Literal["auto", "any"] | None = None,  # noqa: PYI051
    **kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
    """Bind tool-like objects to this chat model.

    Assumes model is compatible with OpenAI tool-calling API.

    Args:
        tools: A list of tool definitions to bind to this chat model.
            Supports any tool definition handled by [`convert_to_openai_tool`][langchain_core.utils.function_calling.convert_to_openai_tool].
        tool_choice: Which tool to require the model to call.
            When it equals the name of one of the provided tools it is
            expanded to the form
            {"type": "function", "function": {"name": <<tool_name>>}}.
            Any other truthy value (for example `'auto'` or an already
            formed dict) is forwarded unchanged.
        kwargs: Any additional parameters are passed directly to
            `self.bind(**kwargs)`.
    """  # noqa: E501
    formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
    if tool_choice:
        known_names = []
        for spec in formatted_tools:
            # Prefer the nested function name, then a top-level "name".
            name = spec["function"].get("name") if "function" in spec else None
            if not name:
                name = spec.get("name")
            if name:
                known_names.append(name)
        if tool_choice in known_names:
            kwargs["tool_choice"] = {
                "type": "function",
                "function": {"name": tool_choice},
            }
        else:
            kwargs["tool_choice"] = tool_choice
    return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: dict | type | None = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict | BaseModel]:
r"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema. Can be passed in as:
- An OpenAI function/tool schema,
- A JSON Schema,
- A `TypedDict` class,
- Or a Pydantic class.
If `schema` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated.
See `langchain_core.utils.function_calling.convert_to_openai_tool` for
more on how to properly specify types and descriptions of schema fields
when specifying a Pydantic or `TypedDict` class.
method: The method for steering model generation, one of:
- `'function_calling'`:
Uses Mistral's
[function-calling feature](https://docs.mistral.ai/capabilities/function_calling/).
- `'json_schema'`:
Uses Mistral's
[structured output feature](https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/).
- `'json_mode'`:
Uses Mistral's
[JSON mode](https://docs.mistral.ai/capabilities/structured-output/json_mode/).
Note that if using JSON mode then you
must include instructions for formatting the output into the
desired schema into the model call.
!!! warning "Behavior changed in `langchain-mistralai` 0.2.5"
Added method="json_schema"
include_raw:
If `False` then only the parsed structured output is returned.
If an error occurs during model output parsing it will be raised.
If `True` then both the raw model response (a `BaseMessage`) and the
parsed model response will be returned.
If an error occurs during output parsing it will be caught and returned
as well.
The final output is always a `dict` with keys `'raw'`, `'parsed'`, and
`'parsing_error'`.
kwargs: Any additional parameters are passed directly to
`self.bind(**kwargs)`. This is useful for passing in
parameters such as `tool_choice` or `tools` to control
which tool the model should call, or to pass in parameters such as
`stop` to control when the model should stop generating output.
Returns:
A `Runnable` that takes same inputs as a
`langchain_core.language_models.chat.BaseChatModel`. If `include_raw` is
`False` and `schema` is a Pydantic class, `Runnable` outputs an instance
of `schema` (i.e., a Pydantic object). Otherwise, if `include_raw` is
`False` then `Runnable` outputs a `dict`.
If `include_raw` is `True`, then `Runnable` outputs a `dict` with keys:
- `'raw'`: `BaseMessage`
- `'parsed'`: `None` if there was a parsing error, otherwise the type
depends on the `schema` as described above.
- `'parsing_error'`: `BaseException | None`
Example: schema=Pydantic class, method="function_calling", include_raw=False:
```python
from typing import Optional
from langchain_mistralai import ChatMistralAI
from pydantic import BaseModel, Field
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
# If we provide default values and/or descriptions for fields, these will be passed
# to the model. This is an important part of improving a model's ability to
# correctly return structured outputs.
justification: str | None = Field(
default=None, description="A justification for the answer."
)
model = ChatMistralAI(model="mistral-large-latest", temperature=0)
structured_model = model.with_structured_output(AnswerWithJustification)
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
```
Example: schema=Pydantic class, method="function_calling", include_raw=True:
```python
from langchain_mistralai import ChatMistralAI
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
model = ChatMistralAI(model="mistral-large-latest", temperature=0)
structured_model = model.with_structured_output(
AnswerWithJustification, include_raw=True
)
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
```
Example: schema=TypedDict class, method="function_calling", include_raw=False:
```python
from typing_extensions import Annotated, TypedDict
from langchain_mistralai import ChatMistralAI
class AnswerWithJustification(TypedDict):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: Annotated[
str | None, None, "A justification for the answer."
]
model = ChatMistralAI(model="mistral-large-latest", temperature=0)
structured_model = model.with_structured_output(AnswerWithJustification)
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
```
Example: schema=OpenAI function schema, method="function_calling", include_raw=False:
```python
from langchain_mistralai import ChatMistralAI
oai_schema = {
'name': 'AnswerWithJustification',
'description': 'An answer to the user question along with justification for the answer.',
'parameters': {
'type': 'object',
'properties': {
'answer': {'type': 'string'},
'justification': {'description': 'A justification for the answer.', 'type': 'string'}
},
'required': ['answer']
}
model = ChatMistralAI(model="mistral-large-latest", temperature=0)
structured_model = model.with_structured_output(oai_schema)
structured_model.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
```
Example: schema=Pydantic class, method="json_mode", include_raw=True:
```python
from langchain_mistralai import ChatMistralAI
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
answer: str
justification: str
model = ChatMistralAI(model="mistral-large-latest", temperature=0)
structured_model = model.with_structured_output(
AnswerWithJustification, method="json_mode", include_raw=True
)
structured_model.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
# 'parsing_error': None
# }
```
Example: schema=None, method="json_mode", include_raw=True:
```python
structured_model = model.with_structured_output(
method="json_mode", include_raw=True
)
structured_model.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
# 'parsed': {
# 'answer': 'They are both the same weight.',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
# },
# 'parsing_error': None
# }
```
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
msg = f"Received unsupported arguments {kwargs}"
raise ValueError(msg)
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling":
if schema is None:
msg = (
"schema must be specified when method is 'function_calling'. "
"Received None."
)
raise ValueError(msg)
# TODO: Update to pass in tool name as tool_choice if/when Mistral supports
# specifying a tool.
llm = self.bind_tools(
[schema],
tool_choice="any",
ls_structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": schema,
},
)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(
response_format={"type": "json_object"},
ls_structured_output_format={
"kwargs": {
# this is correct - name difference with mistral api
"method": "json_mode"
},
"schema": schema,
},
)
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
elif method == "json_schema":
if schema is None:
msg = (
"schema must be specified when method is 'json_schema'. "
"Received None."
)
raise ValueError(msg)
response_format = _convert_to_openai_response_format(schema, strict=True)
llm = self.bind(
response_format=response_format,
ls_structured_output_format={
"kwargs": {"method": "json_schema"},
"schema": schema,
},
)
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser
@property
def _identifying_params(self) -> dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "mistralai-chat"
@property
def lc_secrets(self) -> dict[str, str]:
    """Map secret constructor argument names to environment variable names.

    Returns:
        A fresh single-entry dict pairing `mistral_api_key` with
        `MISTRAL_API_KEY`.
    """
    secrets = dict(mistral_api_key="MISTRAL_API_KEY")
    return secrets
@classmethod
def is_lc_serializable(cls) -> bool:
    """Report whether LangChain may serialize this model.

    Returns:
        Always `True`.
    """
    serializable = True
    return serializable
@classmethod
def get_lc_namespace(cls) -> list[str]:
    """Return the namespace path of this LangChain object.

    Returns:
        A new list on every call:
        `["langchain", "chat_models", "mistralai"]`
    """
    namespace_parts = ("langchain", "chat_models", "mistralai")
    return list(namespace_parts)
def _convert_to_openai_response_format(
schema: dict[str, Any] | type, *, strict: bool | None = None
) -> dict:
"""Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""
if (
isinstance(schema, dict)
and "json_schema" in schema
and schema.get("type") == "json_schema"
):
response_format = schema
elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
response_format = {"type": "json_schema", "json_schema": schema}
else:
if strict is None:
if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
strict = schema["strict"]
else:
strict = False
function = convert_to_openai_tool(schema, strict=strict)["function"]
function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function}
if (
strict is not None
and strict is not response_format["json_schema"].get("strict")
and isinstance(schema, dict)
):
msg = (
f"Output schema already has 'strict' value set to "
f"{schema['json_schema']['strict']} but 'strict' also passed in to "
f"with_structured_output as {strict}. Please make sure that "
f"'strict' is only specified in one place."
)
raise ValueError(msg)
return response_format