mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 23:13:31 +00:00
mistralai: support method="json_schema" in structured output (#29461)
https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@@ -10,6 +11,7 @@ from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_mistralai.chat_models import ChatMistralAI
|
||||
|
||||
@@ -176,6 +178,65 @@ def test_streaming_structured_output() -> None:
|
||||
chunk_num += 1
|
||||
|
||||
|
||||
class Book(BaseModel):
|
||||
name: str
|
||||
authors: list[str]
|
||||
|
||||
|
||||
class BookDict(TypedDict):
|
||||
name: str
|
||||
authors: list[str]
|
||||
|
||||
|
||||
def _check_parsed_result(result: Any, schema: Any) -> None:
|
||||
if schema == Book:
|
||||
assert isinstance(result, Book)
|
||||
elif schema == BookDict:
|
||||
assert all(key in ["name", "authors"] for key in result.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
|
||||
def test_structured_output_json_schema(schema: Any) -> None:
|
||||
llm = ChatMistralAI(model="ministral-8b-latest") # type: ignore[call-arg]
|
||||
structured_llm = llm.with_structured_output(schema, method="json_schema")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "Extract the book's information."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
|
||||
},
|
||||
]
|
||||
# Test invoke
|
||||
result = structured_llm.invoke(messages)
|
||||
_check_parsed_result(result, schema)
|
||||
|
||||
# Test stream
|
||||
for chunk in structured_llm.stream(messages):
|
||||
_check_parsed_result(chunk, schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
|
||||
async def test_structured_output_json_schema_async(schema: Any) -> None:
|
||||
llm = ChatMistralAI(model="ministral-8b-latest") # type: ignore[call-arg]
|
||||
structured_llm = llm.with_structured_output(schema, method="json_schema")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "Extract the book's information."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
|
||||
},
|
||||
]
|
||||
# Test invoke
|
||||
result = await structured_llm.ainvoke(messages)
|
||||
_check_parsed_result(result, schema)
|
||||
|
||||
# Test stream
|
||||
async for chunk in structured_llm.astream(messages):
|
||||
_check_parsed_result(chunk, schema)
|
||||
|
||||
|
||||
def test_tool_call() -> None:
|
||||
llm = ChatMistralAI(model="mistral-large-latest", temperature=0) # type: ignore[call-arg]
|
||||
|
||||
|
Reference in New Issue
Block a user