Mirror of https://github.com/hwchase17/langchain.git, synced 2025-05-28 10:39:23 +00:00.
mistralai: support method="json_schema" in structured output (#29461)
https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/
This commit is contained in:
parent e120378695
commit ca9d4e4595
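This change adds a third structured-output method to `ChatMistralAI.with_structured_output`, backed by Mistral's native structured-output endpoint. A minimal usage sketch (the model name and `Book` schema mirror the integration tests added in this commit):

```python
from pydantic import BaseModel

from langchain_mistralai.chat_models import ChatMistralAI


class Book(BaseModel):
    name: str
    authors: list[str]


llm = ChatMistralAI(model="ministral-8b-latest")
structured_llm = llm.with_structured_output(Book, method="json_schema")

result = structured_llm.invoke(
    "I recently read 'To Kill a Mockingbird' by Harper Lee."
)
# result is a Book instance parsed from the schema-constrained JSON response
```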
@@ -686,7 +686,9 @@ class ChatMistralAI(BaseChatModel):
         self,
         schema: Optional[Union[Dict, Type]] = None,
         *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
         include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -710,13 +712,25 @@ class ChatMistralAI(BaseChatModel):
 
                 Added support for TypedDict class.
 
-            method:
-                The method for steering model generation, either "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+            method: The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses Mistral's
+                    `function-calling feature <https://docs.mistral.ai/capabilities/function_calling/>`_.
+                - "json_schema":
+                    Uses Mistral's
+                    `structured output feature <https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/>`_.
+                - "json_mode":
+                    Uses Mistral's
+                    `JSON mode <https://docs.mistral.ai/capabilities/structured-output/json_mode/>`_.
+                    Note that if using JSON mode then you
+                    must include instructions for formatting the output into the
+                    desired schema into the model call.
+
+                .. versionchanged:: 0.2.5
+
+                        Added method="json_schema"
+
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
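For contrast, a short sketch of how the documented methods differ in what the caller must supply (the prompts and `AnswerWithJustification` schema are taken from the docstring examples further down; the model name is illustrative):

```python
from pydantic import BaseModel

from langchain_mistralai.chat_models import ChatMistralAI


class AnswerWithJustification(BaseModel):
    answer: str
    justification: str


llm = ChatMistralAI(model="mistral-large-latest")

# "function_calling" and "json_schema" steer generation via the schema itself:
structured_llm = llm.with_structured_output(
    AnswerWithJustification, method="json_schema"
)
structured_llm.invoke("What's heavier a pound of bricks or a pound of feathers?")

# "json_mode" only guarantees syntactically valid JSON, so the prompt itself
# must describe the expected keys:
json_mode_llm = llm.with_structured_output(
    AnswerWithJustification, method="json_mode"
)
json_mode_llm.invoke(
    "Answer the following question. "
    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
    "What's heavier a pound of bricks or a pound of feathers?"
)
```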
@@ -877,11 +891,11 @@ class ChatMistralAI(BaseChatModel):
 
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                 #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
                 #     'parsing_error': None
                 # }
@@ -893,17 +907,18 @@ class ChatMistralAI(BaseChatModel):
 
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                 #     'parsed': {
                 #         'answer': 'They are both the same weight.',
                 #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
                 #     },
                 #     'parsing_error': None
                 # }
 
         """  # noqa: E501
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
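The only substantive change in the two docstring hunks above is the doubled backslashes: inside a normal (non-raw) docstring, `\n` is an actual newline, while `\\n` keeps a literal backslash-n in the rendered example. A quick illustration:

```python
# "\n" is a real newline character; "\\n" is the two characters backslash + n.
print("a\nb")   # prints two lines
print("a\\nb")  # prints one line: a\nb
```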
@@ -934,6 +949,20 @@ class ChatMistralAI(BaseChatModel):
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'json_schema'. "
+                    "Received None."
+                )
+            response_format = _convert_to_openai_response_format(schema, strict=True)
+            llm = self.bind(response_format=response_format)
+
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
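A rough sketch of what the new branch binds onto every request. The payload shape here is an assumption based on Mistral's structured-output docs; in the real code path this dict is produced by `_convert_to_openai_response_format` (defined in the next hunk), not written by hand:

```python
from langchain_mistralai.chat_models import ChatMistralAI

llm = ChatMistralAI(model="ministral-8b-latest")

# Illustrative response_format payload, assumed to match what
# _convert_to_openai_response_format(schema, strict=True) builds.
bound = llm.bind(
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "Book",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "authors": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["name", "authors"],
            },
            "strict": True,
        },
    }
)
bound.invoke("I recently read 'To Kill a Mockingbird' by Harper Lee.")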
@@ -969,3 +998,38 @@ class ChatMistralAI(BaseChatModel):
     def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "chat_models", "mistralai"]
+
+
+def _convert_to_openai_response_format(
+    schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
+) -> Dict:
+    """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
+    if (
+        isinstance(schema, dict)
+        and "json_schema" in schema
+        and schema.get("type") == "json_schema"
+    ):
+        response_format = schema
+    elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
+        response_format = {"type": "json_schema", "json_schema": schema}
+    else:
+        if strict is None:
+            if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
+                strict = schema["strict"]
+            else:
+                strict = False
+        function = convert_to_openai_tool(schema, strict=strict)["function"]
+        function["schema"] = function.pop("parameters")
+        response_format = {"type": "json_schema", "json_schema": function}
+
+    if strict is not None and strict is not response_format["json_schema"].get(
+        "strict"
+    ):
+        msg = (
+            f"Output schema already has 'strict' value set to "
+            f"{schema['json_schema']['strict']} but 'strict' also passed in to "
+            f"with_structured_output as {strict}. Please make sure that "
+            f"'strict' is only specified in one place."
+        )
+        raise ValueError(msg)
+    return response_format
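An illustrative sketch (not part of the diff, and relying on importing the private helper) of how the three branches of `_convert_to_openai_response_format` behave:

```python
from pydantic import BaseModel

from langchain_mistralai.chat_models import _convert_to_openai_response_format


class Book(BaseModel):
    name: str
    authors: list[str]


# Branch 2: a {"name": ..., "schema": ...} dict is wrapped as-is.
book_schema = {
    "name": "Book",
    "schema": {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "authors": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["name", "authors"],
    },
}
wrapped = _convert_to_openai_response_format(book_schema)
# -> {"type": "json_schema", "json_schema": book_schema}

# Branch 1: a dict that already is a full response_format passes through.
assert _convert_to_openai_response_format(wrapped) == wrapped

# Branch 3: anything else (e.g. a Pydantic class) goes through
# convert_to_openai_tool, with "parameters" renamed to "schema".
response_format = _convert_to_openai_response_format(Book, strict=True)
# Roughly: {"type": "json_schema",
#           "json_schema": {"name": "Book", "strict": True,
#                           "schema": {...Book's JSON schema...}}}
```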
@@ -3,6 +3,7 @@
 import json
+from typing import Any, Optional
 
 import pytest
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -10,6 +11,7 @@ from langchain_core.messages import (
     HumanMessage,
 )
 from pydantic import BaseModel
+from typing_extensions import TypedDict
 
 from langchain_mistralai.chat_models import ChatMistralAI
 
@@ -176,6 +178,65 @@ def test_streaming_structured_output() -> None:
         chunk_num += 1
 
 
+class Book(BaseModel):
+    name: str
+    authors: list[str]
+
+
+class BookDict(TypedDict):
+    name: str
+    authors: list[str]
+
+
+def _check_parsed_result(result: Any, schema: Any) -> None:
+    if schema == Book:
+        assert isinstance(result, Book)
+    elif schema == BookDict:
+        assert all(key in ["name", "authors"] for key in result.keys())
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+def test_structured_output_json_schema(schema: Any) -> None:
+    llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
+    structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+    messages = [
+        {"role": "system", "content": "Extract the book's information."},
+        {
+            "role": "user",
+            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+        },
+    ]
+    # Test invoke
+    result = structured_llm.invoke(messages)
+    _check_parsed_result(result, schema)
+
+    # Test stream
+    for chunk in structured_llm.stream(messages):
+        _check_parsed_result(chunk, schema)
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+async def test_structured_output_json_schema_async(schema: Any) -> None:
+    llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
+    structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+    messages = [
+        {"role": "system", "content": "Extract the book's information."},
+        {
+            "role": "user",
+            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+        },
+    ]
+    # Test invoke
+    result = await structured_llm.ainvoke(messages)
+    _check_parsed_result(result, schema)
+
+    # Test stream
+    async for chunk in structured_llm.astream(messages):
+        _check_parsed_result(chunk, schema)
+
+
 def test_tool_call() -> None:
     llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]