mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
core, openai[patch]: support serialization of pydantic models in messages (#29940)
Resolves https://github.com/langchain-ai/langchain/issues/29003, https://github.com/langchain-ai/langchain/issues/27264

Related: https://github.com/langchain-ai/langchain-redis/issues/52

```python
from langchain.chat_models import init_chat_model
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from pydantic import BaseModel

cache = SQLiteCache()
set_llm_cache(cache)


class Temperature(BaseModel):
    value: int
    city: str


llm = init_chat_model("openai:gpt-4o-mini")
structured_llm = llm.with_structured_output(Temperature)
```
```python
# 681 ms (first call: cache miss)
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
```python
# 6.98 ms (second call: cache hit)
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
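For context, here is a minimal sketch (not part of the commit; it mirrors the new unit tests below) of the round trip this change enables: a pydantic instance stored under `additional_kwargs["parsed"]` is dumped to a plain dict so the generation can be written to a JSON-backed cache and read back without error.

```python
# Illustrative sketch only; assumes the behavior added in this commit.
from langchain_core.load import dumps, loads
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


gen = ChatGeneration(
    message=AIMessage(
        content='{"value": 21, "city": "Rome"}',
        additional_kwargs={"parsed": Temperature(value=21, city="Rome")},
    )
)

serialized = dumps(gen)  # pre-fix, the BaseModel instance was not JSON-serializable
restored = loads(serialized)
# After the round trip, "parsed" comes back as a plain dict:
assert restored.message.additional_kwargs["parsed"] == {"value": 21, "city": "Rome"}
```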
This commit is contained in:
parent 1645ec1890
commit b1a7f4e106
libs/core/langchain_core/load/dump.py:

```diff
@@ -1,6 +1,8 @@
 import json
 from typing import Any
 
+from pydantic import BaseModel
+
 from langchain_core.load.serializable import Serializable, to_json_not_implemented
 
 
@@ -20,6 +22,23 @@ def default(obj: Any) -> Any:
         return to_json_not_implemented(obj)
 
 
+def _dump_pydantic_models(obj: Any) -> Any:
+    from langchain_core.messages import AIMessage
+    from langchain_core.outputs import ChatGeneration
+
+    if (
+        isinstance(obj, ChatGeneration)
+        and isinstance(obj.message, AIMessage)
+        and (parsed := obj.message.additional_kwargs.get("parsed"))
+        and isinstance(parsed, BaseModel)
+    ):
+        obj_copy = obj.model_copy(deep=True)
+        obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
+        return obj_copy
+    else:
+        return obj
+
+
 def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
     """Return a json string representation of an object.
 
@@ -40,6 +59,7 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
         msg = "`default` should not be passed to dumps"
         raise ValueError(msg)
     try:
+        obj = _dump_pydantic_models(obj)
        if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(obj, default=default, indent=indent, **kwargs)
```
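A quick way to see the helper's contract (a sketch, assuming `_dump_pydantic_models` above is importable from `langchain_core.load.dump`; it is private API and may move): the deep copy keeps the caller's generation untouched, and only the copy carries the dumped dict.

```python
# Sketch under the assumption that the private helper above lives in
# langchain_core.load.dump.
from langchain_core.load.dump import _dump_pydantic_models
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from pydantic import BaseModel


class Point(BaseModel):
    x: int


gen = ChatGeneration(
    message=AIMessage(content="", additional_kwargs={"parsed": Point(x=1)})
)
dumped = _dump_pydantic_models(gen)

# The original is not mutated; only the deep copy holds the plain dict.
assert isinstance(gen.message.additional_kwargs["parsed"], BaseModel)
assert dumped.message.additional_kwargs["parsed"] == {"x": 1}
```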
libs/core/tests/unit_tests/load/test_serializable.py:

```diff
@@ -1,7 +1,9 @@
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from langchain_core.load import Serializable, dumpd, load
 from langchain_core.load.serializable import _is_field_useful
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
 
 
 class NonBoolObj:
@@ -203,3 +205,21 @@ def test_str() -> None:
         non_bool=NonBoolObj(),
     )
     assert str(foo) == "content='str' non_bool=NonBoolObj"
+
+
+def test_serialization_with_pydantic() -> None:
+    class MyModel(BaseModel):
+        x: int
+        y: str
+
+    my_model = MyModel(x=1, y="hello")
+    llm_response = ChatGeneration(
+        message=AIMessage(
+            content='{"x": 1, "y": "hello"}', additional_kwargs={"parsed": my_model}
+        )
+    )
+    ser = dumpd(llm_response)
+    deser = load(ser)
+    assert isinstance(deser, ChatGeneration)
+    assert deser.message.content
+    assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
```
libs/partners/openai/langchain_openai/chat_models/base.py:

```diff
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+from functools import partial
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -78,7 +79,12 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_pydantic_field_names
@@ -1436,9 +1442,9 @@ class BaseChatOpenAI(BaseChatModel):
             },
         )
         if is_pydantic_schema:
-            output_parser = _oai_structured_outputs_parser.with_types(
-                output_type=cast(type, schema)
-            )
+            output_parser = RunnableLambda(
+                partial(_oai_structured_outputs_parser, schema=cast(type, schema))
+            ).with_types(output_type=cast(type, schema))
         else:
             output_parser = JsonOutputParser()
     else:
@@ -2517,10 +2523,14 @@ def _convert_to_openai_response_format(
     return response_format
 
 
-@chain
-def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
-        return ai_msg.additional_kwargs["parsed"]
+def _oai_structured_outputs_parser(
+    ai_msg: AIMessage, schema: Type[_BM]
+) -> PydanticBaseModel:
+    if parsed := ai_msg.additional_kwargs.get("parsed"):
+        if isinstance(parsed, dict):
+            return schema(**parsed)
+        else:
+            return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
     else:
```
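Why the parser now takes a schema: after a cache hit, the stored generation has been through JSON, so `additional_kwargs["parsed"]` arrives as a plain dict rather than a model instance, and the parser rehydrates it. A standalone sketch of that branch (`rehydrate` is an illustrative stand-in, not a langchain_openai API):

```python
# Standalone sketch of the rehydration branch added above.
from typing import Any, Type, TypeVar

from pydantic import BaseModel

_BM = TypeVar("_BM", bound=BaseModel)


class Temperature(BaseModel):
    value: int
    city: str


def rehydrate(parsed: Any, schema: Type[_BM]) -> _BM:
    # A cache round trip leaves a dict; a fresh API call leaves a model instance.
    return schema(**parsed) if isinstance(parsed, dict) else parsed


assert rehydrate({"value": 21, "city": "Rome"}, Temperature) == Temperature(
    value=21, city="Rome"
)
```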
libs/partners/openai/tests/unit_tests/chat_models/test_base.py:

```diff
@@ -1,11 +1,13 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
+from functools import partial
 from types import TracebackType
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from langchain_core.load import dumps, loads
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -17,6 +19,8 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGeneration
+from langchain_core.runnables import RunnableLambda
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
@@ -27,6 +31,7 @@ from langchain_openai.chat_models.base import (
     _convert_to_openai_response_format,
     _create_usage_metadata,
     _format_message_content,
+    _oai_structured_outputs_parser,
 )
 
 
@@ -913,3 +918,21 @@ def test_structured_output_old_model() -> None:
     # assert tool calling was used instead of json_schema
     assert "tools" in llm.steps[0].kwargs  # type: ignore
     assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+
+
+def test_structured_outputs_parser() -> None:
+    parsed_response = GenerateUsername(name="alice", hair_color="black")
+    llm_output = ChatGeneration(
+        message=AIMessage(
+            content='{"name": "alice", "hair_color": "black"}',
+            additional_kwargs={"parsed": parsed_response},
+        )
+    )
+    output_parser = RunnableLambda(
+        partial(_oai_structured_outputs_parser, schema=GenerateUsername)
+    )
+    serialized = dumps(llm_output)
+    deserialized = loads(serialized)
+    assert isinstance(deserialized, ChatGeneration)
+    result = output_parser.invoke(deserialized.message)
+    assert result == parsed_response
```