diff --git a/libs/core/langchain_core/load/dump.py b/libs/core/langchain_core/load/dump.py
index 00fae99d528..7d878b7fd4b 100644
--- a/libs/core/langchain_core/load/dump.py
+++ b/libs/core/langchain_core/load/dump.py
@@ -1,6 +1,8 @@
 import json
 from typing import Any
 
+from pydantic import BaseModel
+
 from langchain_core.load.serializable import Serializable, to_json_not_implemented
 
 
@@ -20,6 +22,23 @@ def default(obj: Any) -> Any:
         return to_json_not_implemented(obj)
 
 
+def _dump_pydantic_models(obj: Any) -> Any:
+    from langchain_core.messages import AIMessage
+    from langchain_core.outputs import ChatGeneration
+
+    if (
+        isinstance(obj, ChatGeneration)
+        and isinstance(obj.message, AIMessage)
+        and (parsed := obj.message.additional_kwargs.get("parsed"))
+        and isinstance(parsed, BaseModel)
+    ):
+        obj_copy = obj.model_copy(deep=True)
+        obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
+        return obj_copy
+    else:
+        return obj
+
+
 def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
     """Return a json string representation of an object.
 
@@ -40,6 +59,7 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
         msg = "`default` should not be passed to dumps"
         raise ValueError(msg)
     try:
+        obj = _dump_pydantic_models(obj)
         if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(obj, default=default, indent=indent, **kwargs)
diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py
index 1c8b6772f09..b66369fc02f 100644
--- a/libs/core/tests/unit_tests/load/test_serializable.py
+++ b/libs/core/tests/unit_tests/load/test_serializable.py
@@ -1,7 +1,9 @@
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from langchain_core.load import Serializable, dumpd, load
 from langchain_core.load.serializable import _is_field_useful
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
 
 
 class NonBoolObj:
@@ -203,3 +205,21 @@ def test_str() -> None:
         non_bool=NonBoolObj(),
     )
     assert str(foo) == "content='str' non_bool=NonBoolObj"
+
+
+def test_serialization_with_pydantic() -> None:
+    class MyModel(BaseModel):
+        x: int
+        y: str
+
+    my_model = MyModel(x=1, y="hello")
+    llm_response = ChatGeneration(
+        message=AIMessage(
+            content='{"x": 1, "y": "hello"}', additional_kwargs={"parsed": my_model}
+        )
+    )
+    ser = dumpd(llm_response)
+    deser = load(ser)
+    assert isinstance(deser, ChatGeneration)
+    assert deser.message.content
+    assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index a264fdedba6..3c31ab2b57c 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+from functools import partial
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -78,7 +79,12 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_pydantic_field_names
@@ -1436,9 +1442,9 @@ class BaseChatOpenAI(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                output_parser = _oai_structured_outputs_parser.with_types(
-                    output_type=cast(type, schema)
-                )
+                output_parser = RunnableLambda(
+                    partial(_oai_structured_outputs_parser, schema=cast(type, schema))
+                ).with_types(output_type=cast(type, schema))
             else:
                 output_parser = JsonOutputParser()
         else:
@@ -2517,10 +2523,14 @@ def _convert_to_openai_response_format(
     return response_format
 
 
-@chain
-def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
-        return ai_msg.additional_kwargs["parsed"]
+def _oai_structured_outputs_parser(
+    ai_msg: AIMessage, schema: Type[_BM]
+) -> PydanticBaseModel:
+    if parsed := ai_msg.additional_kwargs.get("parsed"):
+        if isinstance(parsed, dict):
+            return schema(**parsed)
+        else:
+            return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
     else:
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 5f129ef8d37..8f8c6fa0361 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -1,11 +1,13 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
+from functools import partial
 from types import TracebackType
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from langchain_core.load import dumps, loads
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -17,6 +19,8 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGeneration
+from langchain_core.runnables import RunnableLambda
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
@@ -27,6 +31,7 @@ from langchain_openai.chat_models.base import (
     _convert_to_openai_response_format,
     _create_usage_metadata,
     _format_message_content,
+    _oai_structured_outputs_parser,
 )
 
 
@@ -913,3 +918,21 @@ def test_structured_output_old_model() -> None:
     # assert tool calling was used instead of json_schema
     assert "tools" in llm.steps[0].kwargs  # type: ignore
     assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+
+
+def test_structured_outputs_parser() -> None:
+    parsed_response = GenerateUsername(name="alice", hair_color="black")
+    llm_output = ChatGeneration(
+        message=AIMessage(
+            content='{"name": "alice", "hair_color": "black"}',
+            additional_kwargs={"parsed": parsed_response},
+        )
+    )
+    output_parser = RunnableLambda(
+        partial(_oai_structured_outputs_parser, schema=GenerateUsername)
+    )
+    serialized = dumps(llm_output)
+    deserialized = loads(serialized)
+    assert isinstance(deserialized, ChatGeneration)
+    result = output_parser.invoke(deserialized.message)
+    assert result == parsed_response
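
Taken together, the two halves of this diff form a round trip: dumps() now degrades a pydantic "parsed" value in additional_kwargs to a plain dict, so json.dumps no longer raises TypeError, and the reworked parser rehydrates that dict back into the schema class after deserialization (e.g. when a ChatGeneration comes back out of an LLM cache). A minimal sketch of that round trip follows; it is not part of the diff, uses only APIs the diff touches, and the Person schema and message content are illustrative.

    # Sketch only: assumes langchain-core and langchain-openai with this
    # patch applied. Person and the message content are illustrative.
    from functools import partial

    from langchain_core.load import dumps, loads
    from langchain_core.messages import AIMessage
    from langchain_core.outputs import ChatGeneration
    from langchain_core.runnables import RunnableLambda
    from pydantic import BaseModel

    # Private helper, imported here only to demonstrate the new behavior.
    from langchain_openai.chat_models.base import _oai_structured_outputs_parser


    class Person(BaseModel):
        name: str
        age: int


    gen = ChatGeneration(
        message=AIMessage(
            content='{"name": "Ada", "age": 36}',
            additional_kwargs={"parsed": Person(name="Ada", age=36)},
        )
    )

    # dumps() now routes through _dump_pydantic_models, so the pydantic
    # "parsed" value is stored as a dict and the round trip succeeds.
    restored = loads(dumps(gen))
    assert restored.message.additional_kwargs["parsed"] == {"name": "Ada", "age": 36}

    # The parser accepts both forms: a dict (after deserialization) is
    # rebuilt into the schema; a live pydantic instance passes through.
    parser = RunnableLambda(partial(_oai_structured_outputs_parser, schema=Person))
    assert parser.invoke(restored.message) == Person(name="Ada", age=36)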