Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
5f87276c3a cr 2025-01-15 13:35:04 -08:00
Bagatur
cde3b1eb48 core[patch]: make BaseMessage subscriptable 2025-01-15 13:04:06 -08:00
2 changed files with 12 additions and 0 deletions

View File

@@ -118,6 +118,12 @@ class BaseMessage(Serializable):
def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
def __getitem__(self, item: str) -> Any:
if item in self.model_fields and hasattr(self, item):
return getattr(self, item)
else:
raise KeyError(item)
def merge_content(
first_content: Union[str, list[Union[str, dict]]],

View File

@@ -1008,3 +1008,9 @@ def test_tool_message_tool_call_id() -> None:
ToolMessage("foo", tool_call_id=uuid.uuid4())
ToolMessage("foo", tool_call_id=1)
ToolMessage("foo", tool_call_id=1.0)
def test_message_getitem() -> None:
msg = BaseMessage(content="foo", role="bar", id=1, type="baz")
for k in msg.model_fields:
assert msg[k] == getattr(msg, k)