diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 686e4a7e6a8..78739a5fea1 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -686,7 +686,9 @@ class ChatMistralAI(BaseChatModel):
         self,
         schema: Optional[Union[Dict, Type]] = None,
         *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
         include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -710,13 +712,25 @@ class ChatMistralAI(BaseChatModel):
 
                     Added support for TypedDict class.
 
-            method:
-                The method for steering model generation, either "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+            method: The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses Mistral's function-calling feature.
+                - "json_schema":
+                    Uses Mistral's structured output feature.
+                - "json_mode":
+                    Uses Mistral's JSON mode. Note that if using JSON mode then
+                    you must include instructions for formatting the output
+                    into the desired schema in the model call.
+
+                .. versionchanged:: 0.2.5
+
+                    Added method="json_schema"
+
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
@@ -877,11 +891,11 @@ class ChatMistralAI(BaseChatModel):
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n    "answer": "They are both the same weight.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{\\n    "answer": "They are both the same weight.",\\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                 #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
                 #     'parsing_error': None
                 # }
@@ -893,17 +907,18 @@ class ChatMistralAI(BaseChatModel):
                 structured_llm.invoke(
                     "Answer the following question. "
-                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
+                    "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
                     "What's heavier a pound of bricks or a pound of feathers?"
                 )
                 # -> {
-                #     'raw': AIMessage(content='{\n    "answer": "They are both the same weight.",\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
+                #     'raw': AIMessage(content='{\\n    "answer": "They are both the same weight.",\\n    "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
                 #     'parsed': {
                 #         'answer': 'They are both the same weight.',
                 #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
                 #     },
                 #     'parsing_error': None
                 # }
+
         """  # noqa: E501
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
@@ -934,6 +949,20 @@ class ChatMistralAI(BaseChatModel):
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is 'json_schema'. "
+                    "Received None."
+                )
+            response_format = _convert_to_openai_response_format(schema, strict=True)
+            llm = self.bind(response_format=response_format)
+
+            output_parser = (
+                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
@@ -969,3 +998,38 @@ class ChatMistralAI(BaseChatModel):
     def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "chat_models", "mistralai"]
+
+
+def _convert_to_openai_response_format(
+    schema: Union[Dict[str, Any], Type], *, strict: Optional[bool] = None
+) -> Dict:
+    """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
+    if (
+        isinstance(schema, dict)
+        and "json_schema" in schema
+        and schema.get("type") == "json_schema"
+    ):
+        response_format = schema
+    elif isinstance(schema, dict) and "name" in schema and "schema" in schema:
+        response_format = {"type": "json_schema", "json_schema": schema}
+    else:
+        if strict is None:
+            if isinstance(schema, dict) and isinstance(schema.get("strict"), bool):
+                strict = schema["strict"]
+            else:
+                strict = False
+        function = convert_to_openai_tool(schema, strict=strict)["function"]
+        function["schema"] = function.pop("parameters")
+        response_format = {"type": "json_schema", "json_schema": function}
+
+    if strict is not None and strict is not response_format["json_schema"].get(
+        "strict"
+    ):
+        msg = (
+            f"Output schema already has 'strict' value set to "
+            f"{schema['json_schema']['strict']} but 'strict' also passed in to "
+            f"with_structured_output as {strict}. Please make sure that "
+            f"'strict' is only specified in one place."
+        )
+        raise ValueError(msg)
+    return response_format
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index dc67bd93abe..f3ce39bcebf 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -3,6 +3,7 @@
 import json
 from typing import Any, Optional
 
+import pytest
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -10,6 +11,7 @@ from langchain_core.messages import (
     HumanMessage,
 )
 from pydantic import BaseModel
+from typing_extensions import TypedDict
 
 from langchain_mistralai.chat_models import ChatMistralAI
@@ -176,6 +178,65 @@ def test_streaming_structured_output() -> None:
         chunk_num += 1
 
 
+class Book(BaseModel):
+    name: str
+    authors: list[str]
+
+
+class BookDict(TypedDict):
+    name: str
+    authors: list[str]
+
+
+def _check_parsed_result(result: Any, schema: Any) -> None:
+    if schema == Book:
+        assert isinstance(result, Book)
+    elif schema == BookDict:
+        assert all(key in ["name", "authors"] for key in result.keys())
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+def test_structured_output_json_schema(schema: Any) -> None:
+    llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
+    structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+    messages = [
+        {"role": "system", "content": "Extract the book's information."},
+        {
+            "role": "user",
+            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+        },
+    ]
+    # Test invoke
+    result = structured_llm.invoke(messages)
+    _check_parsed_result(result, schema)
+
+    # Test stream
+    for chunk in structured_llm.stream(messages):
+        _check_parsed_result(chunk, schema)
+
+
+@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
+async def test_structured_output_json_schema_async(schema: Any) -> None:
+    llm = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
+    structured_llm = llm.with_structured_output(schema, method="json_schema")
+
+    messages = [
+        {"role": "system", "content": "Extract the book's information."},
+        {
+            "role": "user",
+            "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
+        },
+    ]
+    # Test invoke
+    result = await structured_llm.ainvoke(messages)
+    _check_parsed_result(result, schema)
+
+    # Test stream
+    async for chunk in structured_llm.astream(messages):
+        _check_parsed_result(chunk, schema)
+
+
 def test_tool_call() -> None:
     llm = ChatMistralAI(model="mistral-large-latest", temperature=0)  # type: ignore[call-arg]
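Taken together, the two halves of this diff enable the following usage. This is a minimal sketch, not part of the PR itself: it assumes `langchain-mistralai` with this change applied (0.2.5+), a `MISTRAL_API_KEY` in the environment, and the model name is borrowed from the integration tests above; the exact parsed values are model-dependent.

```python
from pydantic import BaseModel

from langchain_mistralai import ChatMistralAI


class Book(BaseModel):
    name: str
    authors: list[str]


# Assumes MISTRAL_API_KEY is set; model name taken from the tests above.
llm = ChatMistralAI(model="ministral-8b-latest")
structured_llm = llm.with_structured_output(Book, method="json_schema")

book = structured_llm.invoke(
    "I recently read 'To Kill a Mockingbird' by Harper Lee."
)
# Parsed into the Pydantic model, e.g.:
# Book(name='To Kill a Mockingbird', authors=['Harper Lee'])
print(book)
```

With `include_raw=True`, the same call would instead return a dict with `raw`, `parsed`, and `parsing_error` keys, as in the docstring examples above.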
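As for the payload itself, the `else` branch of `_convert_to_openai_response_format` can be replayed standalone to see what gets bound as `response_format`. A sketch under the assumption that a recent `langchain-core` is installed (its `convert_to_openai_tool` accepting `strict=` is what the diff relies on); the commented output shape is illustrative, not captured from a run.

```python
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel


class Book(BaseModel):
    name: str
    authors: list[str]


# Same two steps as the `else` branch above: convert the schema to an
# OpenAI-style tool, then rename "parameters" to "schema" as the
# response_format payload expects.
function = convert_to_openai_tool(Book, strict=True)["function"]
function["schema"] = function.pop("parameters")
response_format = {"type": "json_schema", "json_schema": function}

# response_format now looks roughly like:
# {"type": "json_schema",
#  "json_schema": {"name": "Book", "strict": True,
#                  "schema": {... JSON schema with name/authors ...}}}
print(response_format)
```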