mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-28 06:48:50 +00:00
core[patch]: Fix _get_type in AnyMessage (#26223)
Fix _get_type to work on deserialization path as well and add a unit test.
This commit is contained in:
64
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
64
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""A set of tests that verifies that Union discrimination works correctly with
|
||||
the various pydantic base models.
|
||||
|
||||
These tests can uncover issues that will also arise during regular instantiation
|
||||
of the models (i.e., not necessarily from loading or dumping JSON).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import RootModel, ValidationError
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
AnyMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
|
||||
|
||||
def test_serde_any_message() -> None:
|
||||
"""Test AnyMessage() serder."""
|
||||
|
||||
lc_objects = [
|
||||
HumanMessage(content="human"),
|
||||
HumanMessageChunk(content="human"),
|
||||
AIMessage(content="ai"),
|
||||
AIMessageChunk(content="ai"),
|
||||
SystemMessage(content="sys"),
|
||||
SystemMessageChunk(content="sys"),
|
||||
FunctionMessage(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
FunctionMessageChunk(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
ChatMessage(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
ChatMessageChunk(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
]
|
||||
|
||||
Model = RootModel[AnyMessage]
|
||||
|
||||
for lc_object in lc_objects:
|
||||
d = lc_object.model_dump()
|
||||
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||
obj1 = Model.model_validate(d)
|
||||
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
|
||||
|
||||
with pytest.raises((TypeError, ValidationError)):
|
||||
# Make sure that specifically validation error is raised
|
||||
Model.model_validate({})
|
Reference in New Issue
Block a user