diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 686e4a7e6a8..78739a5fea1 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -686,7 +686,9 @@ class ChatMistralAI(BaseChatModel):
self,
schema: Optional[Union[Dict, Type]] = None,
*,
- method: Literal["function_calling", "json_mode"] = "function_calling",
+ method: Literal[
+ "function_calling", "json_mode", "json_schema"
+ ] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -710,13 +712,25 @@ class ChatMistralAI(BaseChatModel):
Added support for TypedDict class.
- method:
- The method for steering model generation, either "function_calling"
- or "json_mode". If "function_calling" then the schema will be converted
- to an OpenAI function and the returned model will make use of the
- function-calling API. If "json_mode" then OpenAI's JSON mode will be
- used. Note that if using "json_mode" then you must include instructions
- for formatting the output into the desired schema into the model call.
+ method: The method for steering model generation, one of:
+
+ - "function_calling":
+ Uses Mistral's
+ `function-calling feature <https://docs.mistral.ai/capabilities/function_calling/>`_.
+ - "json_schema":
+ Uses Mistral's
+ `structured output feature <https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/>`_.
+ - "json_mode":
+ Uses Mistral's
+ `JSON mode <https://docs.mistral.ai/capabilities/structured-output/json_mode/>`_.
+ Note that if using JSON mode then you
+ must include instructions for formatting the output into the
+ desired schema in the model call.
+
+ .. versionchanged:: 0.2.5
+
+ Added method="json_schema"
+
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
@@ -877,11 +891,11 @@ class ChatMistralAI(BaseChatModel):
structured_llm.invoke(
"Answer the following question. "
- "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+ "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
- # 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+ # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
# 'parsing_error': None
# }
@@ -893,17 +907,18 @@ class ChatMistralAI(BaseChatModel):
structured_llm.invoke(
"Answer the following question. "
- "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+ "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
- # 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+ # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
# 'parsed': {
# 'answer': 'They are both the same weight.',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
# },
# 'parsing_error': None
# }
+
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
@@ -934,6 +949,20 @@ class ChatMistralAI(BaseChatModel):
if is_pydantic_schema
else JsonOutputParser()
)
+ elif method == "json_schema":
+ if schema is None:
+ raise ValueError(
+ "schema must be specified when method is 'json_schema'. "
+ "Received None."
+ )
+ response_format = _convert_to_openai_response_format(schema, strict=True)
+ llm = self.bind(response_format=response_format)
+
+ output_parser = (
+ PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
+ if is_pydantic_schema
+ else JsonOutputParser()
+ )
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
@@ -969,3 +998,38 @@ class ChatMistralAI(BaseChatModel):
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "mistralai"]
+
+
+def _convert_to_openai_response_format(
+ schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
+) -> Dict:
+ """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
+ if (
+ isinstance(schema, dict)
+ and "json_schema" in schema
+ and schema.get("type") == "json_schema"
+ ):
+ response_format = schema
+ elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
+ response_format = {"type": "json_schema", "json_schema": schema}
+ else:
+ if strict is None:
+ if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
+ strict = schema["strict"]
+ else:
+ strict = False
+ function = convert_to_openai_tool(schema, strict=strict)["function"]
+ function["schema"] = function.pop("parameters")
+ response_format = {"type": "json_schema", "json_schema": function}
+
+ if strict is not None and strict is not response_format["json_schema"].get(
+ "strict"
+ ):
+ msg = (
+ f"Output schema already has 'strict' value set to "
+ f"{schema['json_schema']['strict']} but 'strict' also passed in to "
+ f"with_structured_output as {strict}. Please make sure that "
+ f"'strict' is only specified in one place."
+ )
+ raise ValueError(msg)
+ return response_format
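
To make the new helper concrete, here is a sketch of the payload it builds for a Pydantic schema. The exact keys come from `convert_to_openai_tool`, so the field details shown in the comment are indicative rather than exact:

```python
# Sketch of the response_format produced by the new helper (indicative only).
from pydantic import BaseModel

from langchain_mistralai.chat_models import _convert_to_openai_response_format


class Book(BaseModel):
    name: str
    authors: list[str]


response_format = _convert_to_openai_response_format(Book, strict=True)
# response_format resembles:
# {
#     "type": "json_schema",
#     "json_schema": {
#         "name": "Book",
#         "strict": True,
#         "schema": {
#             "type": "object",
#             "properties": {
#                 "name": {"type": "string"},
#                 "authors": {"type": "array", "items": {"type": "string"}},
#             },
#             "required": ["name", "authors"],
#         },
#     },
# }
```
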
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index dc67bd93abe..f3ce39bcebf 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -3,6 +3,7 @@
import json
from typing import Any, Optional

+import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@@ -10,6 +11,7 @@ from langchain_core.messages import (
HumanMessage,
)
from pydantic import BaseModel
+from typing_extensions import TypedDict
from langchain_mistralai.chat_models import ChatMistralAI
@@ -176,6 +178,65 @@ def test_streaming_structured_output() -> None:
chunk_num += 1
+class Book(BaseModel):
+ name: str
+ authors: list[str]
+
+
+class BookDict(TypedDict):
+ name: str
+ authors: list[str]
+
+
+def _check_parsed_result(result: Any, schema: Any) -> None:
+ if schema == Book:
+ assert isinstance(result, Book)
+ elif schema == BookDict:
+ assert all(key in ["name", "authors"] for key in result.keys())
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+def test_structured_output_json_schema(schema: Any) -> None:
+ llm = ChatMistralAI(model="ministral-8b-latest") # type: ignore[call-arg]
+ structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+ messages = [
+ {"role": "system", "content": "Extract the book's information."},
+ {
+ "role": "user",
+ "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+ },
+ ]
+ # Test invoke
+ result = structured_llm.invoke(messages)
+ _check_parsed_result(result, schema)
+
+ # Test stream
+ for chunk in structured_llm.stream(messages):
+ _check_parsed_result(chunk, schema)
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+async def test_structured_output_json_schema_async(schema: Any) -> None:
+ llm = ChatMistralAI(model="ministral-8b-latest") # type: ignore[call-arg]
+ structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+ messages = [
+ {"role": "system", "content": "Extract the book's information."},
+ {
+ "role": "user",
+ "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+ },
+ ]
+ # Test invoke
+ result = await structured_llm.ainvoke(messages)
+ _check_parsed_result(result, schema)
+
+ # Test stream
+ async for chunk in structured_llm.astream(messages):
+ _check_parsed_result(chunk, schema)
+
+
def test_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
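
Finally, a hedged sketch of the streaming path these tests exercise (hypothetical prompt; with a dict/TypedDict schema the JSON parser emits progressively more complete dicts as chunks accumulate):

```python
# Streaming sketch mirroring test_structured_output_json_schema (hypothetical prompt).
from typing_extensions import TypedDict

from langchain_mistralai import ChatMistralAI


class BookDict(TypedDict):
    name: str
    authors: list[str]


llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
structured_llm = llm.with_structured_output(BookDict, method="json_schema")

for chunk in structured_llm.stream(
    "I recently read 'To Kill a Mockingbird' by Harper Lee."
):
    # Each chunk is a parse of the JSON accumulated so far, e.g.
    # {} -> {'name': 'To Kill a'} -> {'name': 'To Kill a Mockingbird', ...}
    print(chunk)
```
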