core[patch]: Fix _get_type in AnyMessage (#26223)

Fix _get_type to work on deserialization path as well and add a unit test.
This commit is contained in:
Eugene Yurtsev
2024-09-09 10:33:18 -04:00
committed by GitHub
parent 179eeead81
commit b8fc82b84b
2 changed files with 74 additions and 1 deletions

View File

@@ -50,7 +50,16 @@ if TYPE_CHECKING:
def _get_type(v: Any) -> str: def _get_type(v: Any) -> str:
return v.type """Get the type associated with the object for serialization purposes."""
if isinstance(v, dict) and "type" in v:
return v["type"]
elif hasattr(v, "type"):
return v.type
else:
raise TypeError(
f"Expected either a dictionary with a 'type' key or an object "
f"with a 'type' attribute. Instead got type {type(v)}."
)
AnyMessage = Annotated[ AnyMessage = Annotated[

View File

@@ -0,0 +1,64 @@
"""A set of tests that verifies that Union discrimination works correctly with
the various pydantic base models.
These tests can uncover issues that will also arise during regular instantiation
of the models (i.e., not necessarily from loading or dumping JSON).
"""
import pytest
from pydantic import RootModel, ValidationError
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
AnyMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
def test_serde_any_message() -> None:
"""Test AnyMessage() serder."""
lc_objects = [
HumanMessage(content="human"),
HumanMessageChunk(content="human"),
AIMessage(content="ai"),
AIMessageChunk(content="ai"),
SystemMessage(content="sys"),
SystemMessageChunk(content="sys"),
FunctionMessage(
name="func",
content="func",
),
FunctionMessageChunk(
name="func",
content="func",
),
ChatMessage(
role="human",
content="human",
),
ChatMessageChunk(
role="human",
content="human",
),
]
Model = RootModel[AnyMessage]
for lc_object in lc_objects:
d = lc_object.model_dump()
assert "type" in d, f"Missing key `type` for {type(lc_object)}"
obj1 = Model.model_validate(d)
assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
with pytest.raises((TypeError, ValidationError)):
# Make sure that specifically validation error is raised
Model.model_validate({})