mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-27 06:18:05 +00:00
core[patch]: Fix _get_type in AnyMessage (#26223)
Fix _get_type to work on deserialization path as well and add a unit test.
This commit is contained in:
@@ -50,7 +50,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def _get_type(v: Any) -> str:
|
||||
return v.type
|
||||
"""Get the type associated with the object for serialization purposes."""
|
||||
if isinstance(v, dict) and "type" in v:
|
||||
return v["type"]
|
||||
elif hasattr(v, "type"):
|
||||
return v.type
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected either a dictionary with a 'type' key or an object "
|
||||
f"with a 'type' attribute. Instead got type {type(v)}."
|
||||
)
|
||||
|
||||
|
||||
AnyMessage = Annotated[
|
||||
|
64
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
64
libs/core/tests/unit_tests/test_pydantic_serde.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""A set of tests that verifies that Union discrimination works correctly with
|
||||
the various pydantic base models.
|
||||
|
||||
These tests can uncover issues that will also arise during regular instantiation
|
||||
of the models (i.e., not necessarily from loading or dumping JSON).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import RootModel, ValidationError
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
AnyMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
|
||||
|
||||
def test_serde_any_message() -> None:
|
||||
"""Test AnyMessage() serder."""
|
||||
|
||||
lc_objects = [
|
||||
HumanMessage(content="human"),
|
||||
HumanMessageChunk(content="human"),
|
||||
AIMessage(content="ai"),
|
||||
AIMessageChunk(content="ai"),
|
||||
SystemMessage(content="sys"),
|
||||
SystemMessageChunk(content="sys"),
|
||||
FunctionMessage(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
FunctionMessageChunk(
|
||||
name="func",
|
||||
content="func",
|
||||
),
|
||||
ChatMessage(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
ChatMessageChunk(
|
||||
role="human",
|
||||
content="human",
|
||||
),
|
||||
]
|
||||
|
||||
Model = RootModel[AnyMessage]
|
||||
|
||||
for lc_object in lc_objects:
|
||||
d = lc_object.model_dump()
|
||||
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
|
||||
obj1 = Model.model_validate(d)
|
||||
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
|
||||
|
||||
with pytest.raises((TypeError, ValidationError)):
|
||||
# Make sure that specifically validation error is raised
|
||||
Model.model_validate({})
|
Reference in New Issue
Block a user