Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-10 23:41:28 +00:00.
core, openai[patch]: support serialization of pydantic models in messages (#29940)
Resolves https://github.com/langchain-ai/langchain/issues/29003, https://github.com/langchain-ai/langchain/issues/27264 Related: https://github.com/langchain-ai/langchain-redis/issues/52 ```python from langchain.chat_models import init_chat_model from langchain.globals import set_llm_cache from langchain_community.cache import SQLiteCache from pydantic import BaseModel cache = SQLiteCache() set_llm_cache(cache) class Temperature(BaseModel): value: int city: str llm = init_chat_model("openai:gpt-4o-mini") structured_llm = llm.with_structured_output(Temperature) ``` ```python # 681 ms response = structured_llm.invoke("What is the average temperature of Rome in May?") ``` ```python # 6.98 ms response = structured_llm.invoke("What is the average temperature of Rome in May?") ```
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
||||
|
||||
|
||||
@@ -20,6 +22,23 @@ def default(obj: Any) -> Any:
|
||||
return to_json_not_implemented(obj)
|
||||
|
||||
|
||||
def _dump_pydantic_models(obj: Any) -> Any:
    """Make a ChatGeneration JSON-serializable when its AIMessage carries a
    pydantic model under ``additional_kwargs["parsed"]``.

    Returns a deep copy with the model replaced by ``model_dump()`` output;
    any object not matching that shape is returned unchanged.
    """
    # Imported lazily to avoid a circular import with langchain_core.messages.
    from langchain_core.messages import AIMessage
    from langchain_core.outputs import ChatGeneration

    if not isinstance(obj, ChatGeneration):
        return obj
    message = obj.message
    if not isinstance(message, AIMessage):
        return obj
    parsed = message.additional_kwargs.get("parsed")
    if not parsed or not isinstance(parsed, BaseModel):
        return obj
    # Deep-copy so serialization never mutates the caller's object.
    dumped = obj.model_copy(deep=True)
    dumped.message.additional_kwargs["parsed"] = parsed.model_dump()
    return dumped
|
||||
|
||||
|
||||
def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
|
||||
"""Return a json string representation of an object.
|
||||
|
||||
@@ -40,6 +59,7 @@ def dumps(obj: Any, *, pretty: bool = False, **kwargs: Any) -> str:
|
||||
msg = "`default` should not be passed to dumps"
|
||||
raise ValueError(msg)
|
||||
try:
|
||||
obj = _dump_pydantic_models(obj)
|
||||
if pretty:
|
||||
indent = kwargs.pop("indent", 2)
|
||||
return json.dumps(obj, default=default, indent=indent, **kwargs)
|
||||
|
@@ -1,7 +1,9 @@
|
||||
from pydantic import ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from langchain_core.load import Serializable, dumpd, load
|
||||
from langchain_core.load.serializable import _is_field_useful
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
|
||||
|
||||
class NonBoolObj:
|
||||
@@ -203,3 +205,21 @@ def test_str() -> None:
|
||||
non_bool=NonBoolObj(),
|
||||
)
|
||||
assert str(foo) == "content='str' non_bool=NonBoolObj"
|
||||
|
||||
|
||||
def test_serialization_with_pydantic() -> None:
    """Round-trip a ChatGeneration whose message stashes a pydantic model
    in ``additional_kwargs["parsed"]``: ``dumpd`` must serialize it and
    ``load`` must restore it as the model's ``model_dump()`` dict."""

    class Payload(BaseModel):
        x: int
        y: str

    payload = Payload(x=1, y="hello")
    generation = ChatGeneration(
        message=AIMessage(
            content='{"x": 1, "y": "hello"}', additional_kwargs={"parsed": payload}
        )
    )
    round_tripped = load(dumpd(generation))
    assert isinstance(round_tripped, ChatGeneration)
    assert round_tripped.message.content
    # The pydantic instance comes back as a plain dict, not the model itself.
    assert round_tripped.message.additional_kwargs["parsed"] == payload.model_dump()
|
||||
|
Reference in New Issue
Block a user