mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 14:26:48 +00:00
core[patch]: Fix _get_type in AnyMessage (#26223)
Fix _get_type to work on deserialization path as well and add a unit test.
This commit is contained in:
@@ -50,7 +50,16 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def _get_type(v: Any) -> str:
|
def _get_type(v: Any) -> str:
|
||||||
return v.type
|
"""Get the type associated with the object for serialization purposes."""
|
||||||
|
if isinstance(v, dict) and "type" in v:
|
||||||
|
return v["type"]
|
||||||
|
elif hasattr(v, "type"):
|
||||||
|
return v.type
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected either a dictionary with a 'type' key or an object "
|
||||||
|
f"with a 'type' attribute. Instead got type {type(v)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AnyMessage = Annotated[
|
AnyMessage = Annotated[
|
||||||
|
64
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
64
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""A set of tests that verifies that Union discrimination works correctly with
|
||||||
|
the various pydantic base models.
|
||||||
|
|
||||||
|
These tests can uncover issues that will also arise during regular instantiation
|
||||||
|
of the models (i.e., not necessarily from loading or dumping JSON).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import RootModel, ValidationError
|
||||||
|
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
AnyMessage,
|
||||||
|
ChatMessage,
|
||||||
|
ChatMessageChunk,
|
||||||
|
FunctionMessage,
|
||||||
|
FunctionMessageChunk,
|
||||||
|
HumanMessage,
|
||||||
|
HumanMessageChunk,
|
||||||
|
SystemMessage,
|
||||||
|
SystemMessageChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_serde_any_message() -> None:
|
||||||
|
"""Test AnyMessage() serder."""
|
||||||
|
|
||||||
|
lc_objects = [
|
||||||
|
HumanMessage(content="human"),
|
||||||
|
HumanMessageChunk(content="human"),
|
||||||
|
AIMessage(content="ai"),
|
||||||
|
AIMessageChunk(content="ai"),
|
||||||
|
SystemMessage(content="sys"),
|
||||||
|
SystemMessageChunk(content="sys"),
|
||||||
|
FunctionMessage(
|
||||||
|
name="func",
|
||||||
|
content="func",
|
||||||
|
),
|
||||||
|
FunctionMessageChunk(
|
||||||
|
name="func",
|
||||||
|
content="func",
|
||||||
|
),
|
||||||
|
ChatMessage(
|
||||||
|
role="human",
|
||||||
|
content="human",
|
||||||
|
),
|
||||||
|
ChatMessageChunk(
|
||||||
|
role="human",
|
||||||
|
content="human",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
Model = RootModel[AnyMessage]
|
||||||
|
|
||||||
|
for lc_object in lc_objects:
|
||||||
|
d = lc_object.model_dump()
|
||||||
|
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||||
|
obj1 = Model.model_validate(d)
|
||||||
|
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
|
||||||
|
|
||||||
|
with pytest.raises((TypeError, ValidationError)):
|
||||||
|
# Make sure that specifically validation error is raised
|
||||||
|
Model.model_validate({})
|
Reference in New Issue
Block a user