core, openai[patch]: support serialization of pydantic models in messages (#29940)
Resolves https://github.com/langchain-ai/langchain/issues/29003, https://github.com/langchain-ai/langchain/issues/27264

Related: https://github.com/langchain-ai/langchain-redis/issues/52

```python
from langchain.chat_models import init_chat_model
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from pydantic import BaseModel

cache = SQLiteCache()
set_llm_cache(cache)


class Temperature(BaseModel):
    value: int
    city: str


llm = init_chat_model("openai:gpt-4o-mini")
structured_llm = llm.with_structured_output(Temperature)
```

```python
# 681 ms
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```

```python
# 6.98 ms
response = structured_llm.invoke("What is the average temperature of Rome in May?")
```
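Background: with an LLM cache set, the second invocation is served from the cache (6.98 ms vs 681 ms), but that requires the `ChatGeneration` to pass through `dumps` — and a pydantic model instance stored under `additional_kwargs["parsed"]` is not JSON-serializable. A minimal sketch of the failure mode and the workaround this change applies, reusing the `Temperature` model from above:

```python
import json

from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


temp = Temperature(value=21, city="Rome")

# A live pydantic model inside a payload breaks json.dumps...
try:
    json.dumps({"parsed": temp})
except TypeError as err:
    print(err)  # Object of type Temperature is not JSON serializable

# ...while its dict dump serializes cleanly, which is what the fix stores.
print(json.dumps({"parsed": temp.model_dump()}))
```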
This commit is contained in: parent 1645ec1890, commit b1a7f4e106
`langchain_core/load/dump.py`:

```diff
@@ -1,6 +1,8 @@
 import json
 from typing import Any
 
+from pydantic import BaseModel
+
 from langchain_core.load.serializable import Serializable, to_json_not_implemented
@@ -20,6 +22,23 @@ def default(obj: Any) -> Any:
     return to_json_not_implemented(obj)
 
 
+def _dump_pydantic_models(obj: Any) -> Any:
+    from langchain_core.messages import AIMessage
+    from langchain_core.outputs import ChatGeneration
+
+    if (
+        isinstance(obj, ChatGeneration)
+        and isinstance(obj.message, AIMessage)
+        and (parsed := obj.message.additional_kwargs.get("parsed"))
+        and isinstance(parsed, BaseModel)
+    ):
+        obj_copy = obj.model_copy(deep=True)
+        obj_copy.message.additional_kwargs["parsed"] = parsed.model_dump()
+        return obj_copy
+    else:
+        return obj
+
+
 def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
     """Return a json string representation of an object.
 
@@ -40,6 +59,7 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
         msg = "`default` should not be passed to dumps"
         raise ValueError(msg)
     try:
+        obj = _dump_pydantic_models(obj)
        if pretty:
             indent = kwargs.pop("indent", 2)
             return json.dumps(obj, default=default, indent=indent, **kwargs)
```
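`_dump_pydantic_models` deep-copies the generation before swapping the parsed model for its dict dump, so serialization never mutates the caller's object. A minimal sketch of that copy-then-rewrite step (reusing the `Temperature` model from the commit message, not code from the diff):

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


parsed = Temperature(value=21, city="Rome")
gen = ChatGeneration(
    message=AIMessage(
        content='{"value": 21, "city": "Rome"}',
        additional_kwargs={"parsed": parsed},
    )
)

# Same copy-then-rewrite step as _dump_pydantic_models:
gen_copy = gen.model_copy(deep=True)
gen_copy.message.additional_kwargs["parsed"] = parsed.model_dump()

# The caller's generation still holds the live model; only the copy was rewritten.
assert isinstance(gen.message.additional_kwargs["parsed"], Temperature)
assert isinstance(gen_copy.message.additional_kwargs["parsed"], dict)
```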
Unit tests for `langchain_core.load`:

```diff
@@ -1,7 +1,9 @@
-from pydantic import ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from langchain_core.load import Serializable, dumpd, load
 from langchain_core.load.serializable import _is_field_useful
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration
 
 
 class NonBoolObj:
@@ -203,3 +205,21 @@ def test_str() -> None:
         non_bool=NonBoolObj(),
     )
     assert str(foo) == "content='str' non_bool=NonBoolObj"
+
+
+def test_serialization_with_pydantic() -> None:
+    class MyModel(BaseModel):
+        x: int
+        y: str
+
+    my_model = MyModel(x=1, y="hello")
+    llm_response = ChatGeneration(
+        message=AIMessage(
+            content='{"x": 1, "y": "hello"}', additional_kwargs={"parsed": my_model}
+        )
+    )
+    ser = dumpd(llm_response)
+    deser = load(ser)
+    assert isinstance(deser, ChatGeneration)
+    assert deser.message.content
+    assert deser.message.additional_kwargs["parsed"] == my_model.model_dump()
```
`langchain_openai/chat_models/base.py`:

```diff
@@ -9,6 +9,7 @@ import os
 import re
 import sys
 import warnings
+from functools import partial
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -78,7 +79,12 @@ from langchain_core.output_parsers.openai_tools import (
     parse_tool_call,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableMap,
+    RunnablePassthrough,
+)
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_pydantic_field_names
@@ -1436,9 +1442,9 @@ class BaseChatOpenAI(BaseChatModel):
                 },
             )
             if is_pydantic_schema:
-                output_parser = _oai_structured_outputs_parser.with_types(
-                    output_type=cast(type, schema)
-                )
+                output_parser = RunnableLambda(
+                    partial(_oai_structured_outputs_parser, schema=cast(type, schema))
+                ).with_types(output_type=cast(type, schema))
             else:
                 output_parser = JsonOutputParser()
         else:
@@ -2517,10 +2523,14 @@ def _convert_to_openai_response_format(
         return response_format
 
 
-@chain
-def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
-    if ai_msg.additional_kwargs.get("parsed"):
-        return ai_msg.additional_kwargs["parsed"]
+def _oai_structured_outputs_parser(
+    ai_msg: AIMessage, schema: Type[_BM]
+) -> PydanticBaseModel:
+    if parsed := ai_msg.additional_kwargs.get("parsed"):
+        if isinstance(parsed, dict):
+            return schema(**parsed)
+        else:
+            return parsed
     elif ai_msg.additional_kwargs.get("refusal"):
         raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
     else:
```
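The new dict branch matters because after a serialization round trip (e.g. a cache hit) `additional_kwargs["parsed"]` comes back as a plain dict rather than a model instance, so the parser rehydrates it against the bound schema. A minimal sketch of the same bind-and-wrap pattern, using a hypothetical `coerce` helper rather than the real parser:

```python
from functools import partial

from langchain_core.runnables import RunnableLambda
from pydantic import BaseModel


class Temperature(BaseModel):
    value: int
    city: str


def coerce(parsed, schema: type[BaseModel]) -> BaseModel:
    # Mirrors the parser's two branches: rebuild the model from a cached
    # dict dump, or pass a live instance straight through.
    return schema(**parsed) if isinstance(parsed, dict) else parsed


# partial binds the schema up front; RunnableLambda then restores the usual
# one-argument Runnable interface, which is why the diff swaps the @chain
# decorator for this pattern — the schema has to vary per call site.
parser = RunnableLambda(partial(coerce, schema=Temperature))

assert parser.invoke({"value": 21, "city": "Rome"}) == Temperature(value=21, city="Rome")
assert parser.invoke(Temperature(value=21, city="Rome")) == Temperature(value=21, city="Rome")
```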
Unit tests for `langchain_openai.chat_models.base`:

```diff
@@ -1,11 +1,13 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
+from functools import partial
 from types import TracebackType
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from langchain_core.load import dumps, loads
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -17,6 +19,8 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGeneration
+from langchain_core.runnables import RunnableLambda
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict
 
@@ -27,6 +31,7 @@ from langchain_openai.chat_models.base import (
     _convert_to_openai_response_format,
     _create_usage_metadata,
     _format_message_content,
+    _oai_structured_outputs_parser,
 )
 
 
@@ -913,3 +918,21 @@ def test_structured_output_old_model() -> None:
     # assert tool calling was used instead of json_schema
     assert "tools" in llm.steps[0].kwargs  # type: ignore
     assert "response_format" not in llm.steps[0].kwargs  # type: ignore
+
+
+def test_structured_outputs_parser() -> None:
+    parsed_response = GenerateUsername(name="alice", hair_color="black")
+    llm_output = ChatGeneration(
+        message=AIMessage(
+            content='{"name": "alice", "hair_color": "black"}',
+            additional_kwargs={"parsed": parsed_response},
+        )
+    )
+    output_parser = RunnableLambda(
+        partial(_oai_structured_outputs_parser, schema=GenerateUsername)
+    )
+    serialized = dumps(llm_output)
+    deserialized = loads(serialized)
+    assert isinstance(deserialized, ChatGeneration)
+    result = output_parser.invoke(deserialized.message)
+    assert result == parsed_response
```