From b8fc82b84ba25515aa4e8740ebbfd0677208af21 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 9 Sep 2024 10:33:18 -0400 Subject: [PATCH] core[patch]: Fix _get_type in AnyMessage (#26223) Fix _get_type to work on deserialization path as well and add a unit test. --- libs/core/langchain_core/messages/utils.py | 11 +++- .../tests/unit_tests/test_pydantic_serde.py | 64 +++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 libs/core/tests/unit_tests/test_pydantic_serde.py diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 8059280044f..ef3eb6e5039 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -50,7 +50,16 @@ if TYPE_CHECKING: def _get_type(v: Any) -> str: - return v.type + """Get the type associated with the object for serialization purposes.""" + if isinstance(v, dict) and "type" in v: + return v["type"] + elif hasattr(v, "type"): + return v.type + else: + raise TypeError( + f"Expected either a dictionary with a 'type' key or an object " + f"with a 'type' attribute. Instead got type {type(v)}." + ) AnyMessage = Annotated[ diff --git a/libs/core/tests/unit_tests/test_pydantic_serde.py b/libs/core/tests/unit_tests/test_pydantic_serde.py new file mode 100644 index 00000000000..c19d14cb249 --- /dev/null +++ b/libs/core/tests/unit_tests/test_pydantic_serde.py @@ -0,0 +1,64 @@ +"""A set of tests that verifies that Union discrimination works correctly with +the various pydantic base models. + +These tests can uncover issues that will also arise during regular instantiation +of the models (i.e., not necessarily from loading or dumping JSON). +""" + +import pytest +from pydantic import RootModel, ValidationError + +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + AnyMessage, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +) + + +def test_serde_any_message() -> None: + """Test AnyMessage() serder.""" + + lc_objects = [ + HumanMessage(content="human"), + HumanMessageChunk(content="human"), + AIMessage(content="ai"), + AIMessageChunk(content="ai"), + SystemMessage(content="sys"), + SystemMessageChunk(content="sys"), + FunctionMessage( + name="func", + content="func", + ), + FunctionMessageChunk( + name="func", + content="func", + ), + ChatMessage( + role="human", + content="human", + ), + ChatMessageChunk( + role="human", + content="human", + ), + ] + + Model = RootModel[AnyMessage] + + for lc_object in lc_objects: + d = lc_object.model_dump() + assert "type" in d, f"Missing key `type` for {type(lc_object)}" + obj1 = Model.model_validate(d) + assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" + + with pytest.raises((TypeError, ValidationError)): + # Make sure that specifically validation error is raised + Model.model_validate({})