core, openai[patch]: support serialization of pydantic models in messages (#29940)

Resolves https://github.com/langchain-ai/langchain/issues/29003,
https://github.com/langchain-ai/langchain/issues/27264
Related: https://github.com/langchain-ai/langchain-redis/issues/52
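
Previously, `dumps`/`dumpd` fell back to `to_json_not_implemented` when a message carried a parsed pydantic model in `additional_kwargs["parsed"]` (as OpenAI structured output does), so LLM caches could not store structured-output generations. Serialization now dumps the model to a plain dict, and the OpenAI output parser re-validates dicts against the schema on the way back: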

```python
from langchain.chat_models import init_chat_model
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from pydantic import BaseModel

cache = SQLiteCache()

set_llm_cache(cache)

class Temperature(BaseModel):
    value: int
    city: str

llm = init_chat_model("openai:gpt-4o-mini")
structured_llm = llm.with_structured_output(Temperature)
```
```python
# First call: cache miss (681 ms)
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
```python
# Second call: cache hit (6.98 ms)
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
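
The second call is served from the cache because the `ChatGeneration` now survives a `dumps`/`loads` round trip. A minimal sketch of that round trip (reusing the `Temperature` model above; the values are illustrative):

```python
from langchain_core.load import dumps, loads
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

gen = ChatGeneration(
    message=AIMessage(
        content='{"value": 18, "city": "Rome"}',
        additional_kwargs={"parsed": Temperature(value=18, city="Rome")},
    )
)

restored = loads(dumps(gen))
# The pydantic model is serialized as a plain dict; the structured-output
# parser re-validates it into Temperature when the cached result is parsed.
assert restored.message.additional_kwargs["parsed"] == {"value": 18, "city": "Rome"}
```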
ccurme committed 2025-02-24 09:34:27 -05:00 (commit b1a7f4e106, parent 1645ec1890)

4 changed files with 82 additions and 9 deletions

View File: langchain_core/load/dump.py

```diff
@@ -1,6 +1,8 @@
 import json
 from typing import Any
 
+from pydantic import BaseModel
+
 from langchain_core.load.serializable import Serializable, to_json_not_implemented
 
 
@@ -20,6 +22,23 @@ def default(obj: Any) -> Any:
         return to_json_not_implemented(obj)
 
 
+def _dump_pydantic_models(obj: Any) -> Any:
+    from langchain_core.messages import AIMessage
+    from langchain_core.outputs import ChatGeneration
+
+    if (
+        isinstance(obj, ChatGeneration)
+        and isinstance(obj.message, AIMessage)
+        and (parsed := obj.message.additional_kwargs.get("parsed"))
+        and isinstance(parsed, BaseModel)
+    ):
+        obj_copy = obj.model_copy(deep=True)
+        obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
+        return obj_copy
+    else:
+        return obj
+
+
 def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
     """Return a json string representation of an object.
 
@@ -40,6 +59,7 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
         msg = "`default` should not be passed to dumps"
         raise ValueError(msg)
     try:
+        obj = _dump_pydantic_models(obj)
        if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(obj, default=default, indent=indent, **kwargs)
```
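
Note that `_dump_pydantic_models` deep-copies the generation before replacing `parsed`, so serializing never mutates the caller's object. A quick sketch of that property (the `Point` model is illustrative):

```python
from langchain_core.load import dumps
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from pydantic import BaseModel


class Point(BaseModel):
    x: int


gen = ChatGeneration(
    message=AIMessage(content='{"x": 1}', additional_kwargs={"parsed": Point(x=1)})
)
dumps(gen)  # serializes a deep copy
assert isinstance(gen.message.additional_kwargs["parsed"], Point)  # unchanged
```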

View File: langchain_core serialization unit tests (test_serializable.py)

```diff
@@ -1,7 +1,9 @@
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from langchain_core.load import Serializable, dumpd, load
 from langchain_core.load.serializable import _is_field_useful
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
 
 
 class NonBoolObj:
@@ -203,3 +205,21 @@ def test_str() -> None:
         non_bool=NonBoolObj(),
     )
     assert str(foo) == "content='str' non_bool=NonBoolObj"
+
+
+def test_serialization_with_pydantic() -> None:
+    class MyModel(BaseModel):
+        x: int
+        y: str
+
+    my_model = MyModel(x=1, y="hello")
+    llm_response = ChatGeneration(
+        message=AIMessage(
+            content='{"x": 1, "y": "hello"}', additional_kwargs={"parsed": my_model}
+        )
+    )
+    ser = dumpd(llm_response)
+    deser = load(ser)
+    assert isinstance(deser, ChatGeneration)
+    assert deser.message.content
+    assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
```
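
Note the final assertion: the round trip deliberately yields the `model_dump()` dict rather than a re-instantiated model, which is why the OpenAI output parser below must accept both forms.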

View File: langchain_openai/chat_models/base.py

```diff
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+from functools import partial
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -78,7 +79,12 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_pydantic_field_names
@@ -1436,9 +1442,9 @@ class BaseChatOpenAI(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                output_parser = _oai_structured_outputs_parser.with_types(
-                    output_type=cast(type, schema)
-                )
+                output_parser = RunnableLambda(
+                    partial(_oai_structured_outputs_parser, schema=cast(type, schema))
+                ).with_types(output_type=cast(type, schema))
             else:
                 output_parser = JsonOutputParser()
         else:
@@ -2517,10 +2523,14 @@ def _convert_to_openai_response_format(
     return response_format
 
 
-@chain
-def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
-        return ai_msg.additional_kwargs["parsed"]
+def _oai_structured_outputs_parser(
+    ai_msg: AIMessage, schema: Type[_BM]
+) -> PydanticBaseModel:
+    if parsed := ai_msg.additional_kwargs.get("parsed"):
+        if isinstance(parsed, dict):
+            return schema(**parsed)
+        else:
+            return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
     else:
```
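
Since the parser now needs the target schema, it can no longer be a bare `@chain` runnable; the schema is bound with `functools.partial` and the result wrapped in `RunnableLambda`. A standalone sketch of that binding pattern (the `parse` function and `Temperature` schema are illustrative, not the library code):

```python
from functools import partial

from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


def parse(payload: dict, schema: type[BaseModel]) -> BaseModel:
    # Re-validate dicts (e.g. after a cache round trip); pass instances through.
    parsed = payload["parsed"]
    return schema(**parsed) if isinstance(parsed, dict) else parsed


parser = RunnableLambda(partial(parse, schema=Temperature))
assert parser.invoke({"parsed": {"value": 18, "city": "Rome"}}) == Temperature(
    value=18, city="Rome"
)
```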

View File: langchain_openai chat model unit tests (test_base.py)

```diff
@@ -1,11 +1,13 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
+from functools import partial
 from types import TracebackType
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from langchain_core.load import dumps, loads
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -17,6 +19,8 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGeneration
+from langchain_core.runnables import RunnableLambda
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
@@ -27,6 +31,7 @@ from langchain_openai.chat_models.base import (
     _convert_to_openai_response_format,
     _create_usage_metadata,
     _format_message_content,
+    _oai_structured_outputs_parser,
 )
 
@@ -913,3 +918,21 @@ def test_structured_output_old_model() -> None:
     # assert tool calling was used instead of json_schema
     assert "tools" in llm.steps[0].kwargs  # type: ignore
     assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+
+
+def test_structured_outputs_parser() -> None:
+    parsed_response = GenerateUsername(name="alice", hair_color="black")
+    llm_output = ChatGeneration(
+        message=AIMessage(
+            content='{"name": "alice", "hair_color": "black"}',
+            additional_kwargs={"parsed": parsed_response},
+        )
+    )
+    output_parser = RunnableLambda(
+        partial(_oai_structured_outputs_parser, schema=GenerateUsername)
+    )
+    serialized = dumps(llm_output)
+    deserialized = loads(serialized)
+    assert isinstance(deserialized, ChatGeneration)
+    result = output_parser.invoke(deserialized.message)
+    assert result == parsed_response
```
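
The test exercises the full failure mode from the linked issues: serialize the generation, load it back (at which point `parsed` is a dict), and confirm the parser re-validates it into the original `GenerateUsername` instance.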