diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 9eab1ed431a..bcc3e86d239 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -118,6 +118,9 @@ class BaseMessage(Serializable): def pretty_print(self) -> None: print(self.pretty_repr(html=is_interactive_env())) # noqa: T201 + def __getitem__(self, item: str) -> Any: + return self.model_dump()[item] + def merge_content( first_content: Union[str, list[Union[str, dict]]], diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index aafcd15e1bb..b42a50f5300 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1008,3 +1008,9 @@ def test_tool_message_tool_call_id() -> None: ToolMessage("foo", tool_call_id=uuid.uuid4()) ToolMessage("foo", tool_call_id=1) ToolMessage("foo", tool_call_id=1.0) + + +def test_message_getitem() -> None: + msg = BaseMessage(content="foo", role="bar", id=1, type="baz") + for k in msg.model_fields: + assert msg[k] == getattr(msg, k)