core, openai[patch]: support serialization of pydantic models in messages (#29940)

Resolves https://github.com/langchain-ai/langchain/issues/29003,
https://github.com/langchain-ai/langchain/issues/27264
Related: https://github.com/langchain-ai/langchain-redis/issues/52
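
Structured-output runs store the parsed pydantic object in the message's `additional_kwargs["parsed"]`. That object was not JSON-serializable, so `dumps` fell back to a not-implemented placeholder and LLM caches could not round-trip structured outputs. This change dumps the model to a plain dict during serialization and teaches the OpenAI output parser to re-validate such dicts back into the schema, which makes caching work: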

```python
from langchain.chat_models import init_chat_model
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from pydantic import BaseModel

cache = SQLiteCache()

set_llm_cache(cache)

class Temperature(BaseModel):
    value: int
    city: str

llm = init_chat_model("openai:gpt-4o-mini")
structured_llm = llm.with_structured_output(Temperature)
```
```python
# First call (cache miss): 681 ms
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
```python
# Second, identical call (cache hit): 6.98 ms
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
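
The cache hit works because serialization of the parsed model now round-trips. Below is a minimal sketch of that round trip (mirroring the unit tests in this commit; it reuses the `Temperature` model defined above, and the values are illustrative):

```python
from langchain_core.load import dumps, loads
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

gen = ChatGeneration(
    message=AIMessage(
        content='{"value": 18, "city": "Rome"}',
        additional_kwargs={"parsed": Temperature(value=18, city="Rome")},
    )
)

# The pydantic model is dumped to a plain dict before JSON encoding, so the
# round trip no longer degrades to a to_json_not_implemented placeholder.
restored = loads(dumps(gen))
assert restored.message.additional_kwargs["parsed"] == {"value": 18, "city": "Rome"}
```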
ccurme authored on 2025-02-24 09:34:27 -05:00 (committed via GitHub)
commit b1a7f4e106 (parent 1645ec1890)
4 changed files with 82 additions and 9 deletions

libs/core/langchain_core/load/dump.py

```diff
@@ -1,6 +1,8 @@
 import json
 from typing import Any
 
+from pydantic import BaseModel
+
 from langchain_core.load.serializable import Serializable, to_json_not_implemented
@@ -20,6 +22,23 @@ def default(obj: Any) -> Any:
     return to_json_not_implemented(obj)
 
 
+def _dump_pydantic_models(obj: Any) -> Any:
+    from langchain_core.messages import AIMessage
+    from langchain_core.outputs import ChatGeneration
+
+    if (
+        isinstance(obj, ChatGeneration)
+        and isinstance(obj.message, AIMessage)
+        and (parsed := obj.message.additional_kwargs.get("parsed"))
+        and isinstance(parsed, BaseModel)
+    ):
+        obj_copy = obj.model_copy(deep=True)
+        obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
+        return obj_copy
+    else:
+        return obj
+
+
 def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
     """Return a json string representation of an object.
@@ -40,6 +59,7 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
         msg = "`default` should not be passed to dumps"
         raise ValueError(msg)
     try:
+        obj = _dump_pydantic_models(obj)
         if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(obj, default=default, indent=indent, **kwargs)
```
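
Note the `model_copy(deep=True)`: serialization must not mutate the caller's generation, which may still be consumed downstream with the live model intact. A small illustrative sketch of that behavior (`_dump_pydantic_models` is a private helper, so this is for exposition only):

```python
from pydantic import BaseModel

from langchain_core.load.dump import _dump_pydantic_models
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration


class Temperature(BaseModel):
    value: int
    city: str


gen = ChatGeneration(
    message=AIMessage(
        content='{"value": 18, "city": "Rome"}',
        additional_kwargs={"parsed": Temperature(value=18, city="Rome")},
    )
)

dumped = _dump_pydantic_models(gen)
# The input is deep-copied, so the caller's generation keeps the live model...
assert isinstance(gen.message.additional_kwargs["parsed"], Temperature)
# ...while the copy to be serialized carries a JSON-safe dict.
assert dumped.message.additional_kwargs["parsed"] == {"value": 18, "city": "Rome"}
```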

libs/core/tests/unit_tests/load/test_serializable.py

```diff
@@ -1,7 +1,9 @@
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from langchain_core.load import Serializable, dumpd, load
 from langchain_core.load.serializable import _is_field_useful
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
 
 
 class NonBoolObj:
@@ -203,3 +205,21 @@ def test_str() -> None:
         non_bool=NonBoolObj(),
     )
     assert str(foo) == "content='str' non_bool=NonBoolObj"
+
+
+def test_serialization_with_pydantic() -> None:
+    class MyModel(BaseModel):
+        x: int
+        y: str
+
+    my_model = MyModel(x=1, y="hello")
+    llm_response = ChatGeneration(
+        message=AIMessage(
+            content='{"x": 1, "y": "hello"}', additional_kwargs={"parsed": my_model}
+        )
+    )
+    ser = dumpd(llm_response)
+    deser = load(ser)
+    assert isinstance(deser, ChatGeneration)
+    assert deser.message.content
+    assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
```
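
Note what the core-side test pins down: after a `dumpd`/`load` round trip, `parsed` comes back as a plain dict, not a `MyModel` instance. Rehydrating the dict into the schema is left to the consumer, which is exactly what the parser change in `langchain_openai` below handles.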

libs/partners/openai/langchain_openai/chat_models/base.py

```diff
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+from functools import partial
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -78,7 +79,12 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_pydantic_field_names
@@ -1436,9 +1442,9 @@ class BaseChatOpenAI(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                output_parser = _oai_structured_outputs_parser.with_types(
-                    output_type=cast(type, schema)
-                )
+                output_parser = RunnableLambda(
+                    partial(_oai_structured_outputs_parser, schema=cast(type, schema))
+                ).with_types(output_type=cast(type, schema))
             else:
                 output_parser = JsonOutputParser()
         else:
@@ -2517,10 +2523,14 @@ def _convert_to_openai_response_format(
         return response_format
 
 
-@chain
-def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
-        return ai_msg.additional_kwargs["parsed"]
+def _oai_structured_outputs_parser(
+    ai_msg: AIMessage, schema: Type[_BM]
+) -> PydanticBaseModel:
+    if parsed := ai_msg.additional_kwargs.get("parsed"):
+        if isinstance(parsed, dict):
+            return schema(**parsed)
+        else:
+            return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
     else:
```
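
The parser now accepts both shapes of `parsed`: a live `BaseModel` from a fresh model call, and a plain dict from a deserialized (e.g. cached) generation, which it re-validates into the schema. A hedged sketch of the two paths (`_oai_structured_outputs_parser` is private, and the model and values here are illustrative):

```python
from functools import partial

from langchain_core.messages import AIMessage
from pydantic import BaseModel

from langchain_openai.chat_models.base import _oai_structured_outputs_parser


class Temperature(BaseModel):
    value: int
    city: str


parse = partial(_oai_structured_outputs_parser, schema=Temperature)

# Fresh call: additional_kwargs["parsed"] holds a live pydantic instance.
live = AIMessage(
    content="", additional_kwargs={"parsed": Temperature(value=18, city="Rome")}
)
# Cache hit: the instance was dumped to a dict during serialization.
cached = AIMessage(
    content="", additional_kwargs={"parsed": {"value": 18, "city": "Rome"}}
)

assert parse(live) == parse(cached) == Temperature(value=18, city="Rome")
```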

libs/partners/openai/tests/unit_tests/chat_models/test_base.py

```diff
@@ -1,11 +1,13 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
+from functools import partial
 from types import TracebackType
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from langchain_core.load import dumps, loads
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -17,6 +19,8 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGeneration
+from langchain_core.runnables import RunnableLambda
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
@@ -27,6 +31,7 @@ from langchain_openai.chat_models.base import (
     _convert_to_openai_response_format,
     _create_usage_metadata,
     _format_message_content,
+    _oai_structured_outputs_parser,
 )
@@ -913,3 +918,21 @@ def test_structured_output_old_model() -> None:
     # assert tool calling was used instead of json_schema
     assert "tools" in llm.steps[0].kwargs  # type: ignore
     assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+
+
+def test_structured_outputs_parser() -> None:
+    parsed_response = GenerateUsername(name="alice", hair_color="black")
+    llm_output = ChatGeneration(
+        message=AIMessage(
+            content='{"name": "alice", "hair_color": "black"}',
+            additional_kwargs={"parsed": parsed_response},
+        )
+    )
+    output_parser = RunnableLambda(
+        partial(_oai_structured_outputs_parser, schema=GenerateUsername)
+    )
+    serialized = dumps(llm_output)
+    deserialized = loads(serialized)
+    assert isinstance(deserialized, ChatGeneration)
+    result = output_parser.invoke(deserialized.message)
+    assert result == parsed_response
```
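
The `RunnableLambda(partial(...))` construction in this test mirrors what `with_structured_output` now builds internally (see the `base.py` hunk above), so the test exercises the same serialize, deserialize, and rehydrate path a cached structured-output call takes in production.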