mistralai: support method="json_schema" in structured output (#29461)

https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/
ccurme 2025-01-28 18:17:39 -05:00 committed by GitHub
parent e120378695
commit ca9d4e4595
2 changed files with 137 additions and 12 deletions
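Before the file diffs, a minimal sketch of what the new option looks like in use. This is illustrative, not part of the commit: the `Book` schema and prompt are invented, and `MISTRAL_API_KEY` is assumed to be set in the environment.

from pydantic import BaseModel

from langchain_mistralai import ChatMistralAI


class Book(BaseModel):
    # Illustrative schema; any Pydantic model, TypedDict, or JSON schema dict works.
    name: str
    authors: list[str]


llm = ChatMistralAI(model="mistral-large-latest")

# method="json_schema" is what this commit adds: the schema is sent via
# Mistral's structured-output (response_format) API instead of as a tool.
structured_llm = llm.with_structured_output(Book, method="json_schema")
book = structured_llm.invoke("I recently read 'To Kill a Mockingbird' by Harper Lee.")
# book is a Book instance parsed from the schema-constrained JSON output.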

libs/partners/mistralai/langchain_mistralai/chat_models.py

@@ -686,7 +686,9 @@ class ChatMistralAI(BaseChatModel):
         self,
         schema: Optional[Union[Dict, Type]] = None,
         *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
         include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -710,13 +712,25 @@ class ChatMistralAI(BaseChatModel):
                         Added support for TypedDict class.

-            method:
-                The method for steering model generation, either "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+            method: The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses Mistral's
+                    `function-calling feature <https://docs.mistral.ai/capabilities/function_calling/>`_.
+                - "json_schema":
+                    Uses Mistral's
+                    `structured output feature <https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/>`_.
+                - "json_mode":
+                    Uses Mistral's
+                    `JSON mode <https://docs.mistral.ai/capabilities/structured-output/json_mode/>`_.
+                    Note that if using JSON mode then you
+                    must include instructions for formatting the output into the
+                    desired schema into the model call.
+
+                .. versionchanged:: 0.2.5
+
+                    Added method="json_schema"
+
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
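The hunk above documents the practical difference between the options. A sketch with an invented `Answer` schema (not part of the commit): "json_mode" only constrains the output format, so the schema must be described in the prompt, while "json_schema" ships the schema with the request.

from pydantic import BaseModel

from langchain_mistralai import ChatMistralAI


class Answer(BaseModel):
    answer: str


llm = ChatMistralAI(model="mistral-large-latest")

# "json_mode" only guarantees syntactically valid JSON, so the desired
# keys have to be spelled out in the prompt itself:
json_mode_llm = llm.with_structured_output(Answer, method="json_mode")
json_mode_llm.invoke(
    "Reply with a JSON object with a single key 'answer'. Why is the sky blue?"
)

# "json_schema" transmits the schema with the request, so the prompt can
# omit any formatting instructions:
json_schema_llm = llm.with_structured_output(Answer, method="json_schema")
json_schema_llm.invoke("Why is the sky blue?")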
@@ -877,11 +891,11 @@ class ChatMistralAI(BaseChatModel):
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                 #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
                 #     'parsing_error': None
                 # }
@@ -893,17 +907,18 @@ class ChatMistralAI(BaseChatModel):
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                 #     'parsed': {
                 #         'answer': 'They are both the same weight.',
                 #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
                 #     },
                 #     'parsing_error': None
                 # }
         """  # noqa: E501
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
@@ -934,6 +949,20 @@ class ChatMistralAI(BaseChatModel):
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'json_schema'. "
+                    "Received None."
+                )
+            response_format = _convert_to_openai_response_format(schema, strict=True)
+            llm = self.bind(response_format=response_format)
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
+
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
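As the new branch above enforces, a schema is mandatory for this method. A quick illustration of the failure mode (the model name is taken from the tests below; `MISTRAL_API_KEY` is assumed to be set):

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="ministral-8b-latest")

# Omitting the schema with method="json_schema" fails fast, before any
# request is sent:
try:
    llm.with_structured_output(method="json_schema")
except ValueError as err:
    print(err)  # schema must be specified when method is 'json_schema'. Received None.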
@@ -969,3 +998,38 @@ class ChatMistralAI(BaseChatModel):
     def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "chat_models", "mistralai"]
+
+
+def _convert_to_openai_response_format(
+    schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
+) -> Dict:
+    """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
+    if (
+        isinstance(schema, dict)
+        and "json_schema" in schema
+        and schema.get("type") == "json_schema"
+    ):
+        response_format = schema
+    elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
+        response_format = {"type": "json_schema", "json_schema": schema}
+    else:
+        if strict is None:
+            if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
+                strict = schema["strict"]
+            else:
+                strict = False
+        function = convert_to_openai_tool(schema, strict=strict)["function"]
+        function["schema"] = function.pop("parameters")
+        response_format = {"type": "json_schema", "json_schema": function}
+
+    if strict is not None and strict is not response_format["json_schema"].get(
+        "strict"
+    ):
+        msg = (
+            f"Output schema already has 'strict' value set to "
+            f"{schema['json_schema']['strict']} but 'strict' also passed in to "
+            f"with_structured_output as {strict}. Please make sure that "
+            f"'strict' is only specified in one place."
+        )
+        raise ValueError(msg)
+    return response_format
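To make the helper's branches concrete, a sketch of the `response_format` it produces for a bare JSON schema. The `Book` schema is invented for illustration, and the output shown is approximate.

# A bare JSON schema matches neither the pass-through branch (a ready
# response_format dict) nor the name/schema branch, so it is run through
# convert_to_openai_tool and its "parameters" key is renamed to "schema":
book_schema = {
    "title": "Book",
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "authors": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["name", "authors"],
}
_convert_to_openai_response_format(book_schema, strict=True)
# -> roughly:
# {
#     "type": "json_schema",
#     "json_schema": {
#         "name": "Book",
#         "schema": {... the object schema, with strict-mode adjustments ...},
#         "strict": True,
#     },
# }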

libs/partners/mistralai/tests/integration_tests/test_chat_models.py

@@ -3,6 +3,7 @@
 import json
 from typing import Any, Optional

+import pytest
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -10,6 +11,7 @@ from langchain_core.messages import (
     HumanMessage,
 )
 from pydantic import BaseModel
+from typing_extensions import TypedDict

 from langchain_mistralai.chat_models import ChatMistralAI
@@ -176,6 +178,65 @@ def test_streaming_structured_output() -> None:
         chunk_num += 1


+class Book(BaseModel):
+    name: str
+    authors: list[str]
+
+
+class BookDict(TypedDict):
+    name: str
+    authors: list[str]
+
+
+def _check_parsed_result(result: Any, schema: Any) -> None:
+    if schema == Book:
+        assert isinstance(result, Book)
+    elif schema == BookDict:
+        assert all(key in ["name", "authors"] for key in result.keys())
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+def test_structured_output_json_schema(schema: Any) -> None:
+    llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
+    structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+    messages = [
+        {"role": "system", "content": "Extract the book's information."},
+        {
+            "role": "user",
+            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+        },
+    ]
+
+    # Test invoke
+    result = structured_llm.invoke(messages)
+    _check_parsed_result(result, schema)
+
+    # Test stream
+    for chunk in structured_llm.stream(messages):
+        _check_parsed_result(chunk, schema)
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+async def test_structured_output_json_schema_async(schema: Any) -> None:
+    llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
+    structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+    messages = [
+        {"role": "system", "content": "Extract the book's information."},
+        {
+            "role": "user",
+            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+        },
+    ]
+
+    # Test invoke
+    result = await structured_llm.ainvoke(messages)
+    _check_parsed_result(result, schema)
+
+    # Test stream
+    async for chunk in structured_llm.astream(messages):
+        _check_parsed_result(chunk, schema)
+
+
 def test_tool_call() -> None:
     llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]