messages.v1 mypy fixes

This commit is contained in:
Mason Daugherty 2025-07-31 17:32:59 -04:00
parent 7a0c3e0482
commit 05d075a37f
No known key found for this signature in database

View File

@ -177,9 +177,10 @@ class AIMessage:
if "id" in tool_call and tool_call["id"] in content_tool_calls: if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue continue
self.content.append(tool_call) self.content.append(tool_call)
self._tool_calls = [ self._tool_calls: list[types.ToolCall] = cast(
block for block in self.content if block.get("type") == "tool_call" "list[types.ToolCall]",
] [block for block in self.content if block.get("type") == "tool_call"],
)
self.invalid_tool_calls = invalid_tool_calls or [] self.invalid_tool_calls = invalid_tool_calls or []
@property @property
@ -187,7 +188,9 @@ class AIMessage:
"""Extract all text content from the AI message as a string.""" """Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if block.get("type") == "text"] text_blocks = [block for block in self.content if block.get("type") == "text"]
if text_blocks: if text_blocks:
return "".join(block["text"] for block in text_blocks) return "".join(
cast("types.TextContentBlock", block)["text"] for block in text_blocks
)
return None return None
@property @property
@ -195,10 +198,12 @@ class AIMessage:
"""Get the tool calls made by the AI.""" """Get the tool calls made by the AI."""
if self._tool_calls: if self._tool_calls:
return self._tool_calls return self._tool_calls
tool_calls = [block for block in self.content if block["type"] == "tool_call"] tool_calls = [
block for block in self.content if block.get("type") == "tool_call"
]
if tool_calls: if tool_calls:
self._tool_calls = tool_calls self._tool_calls = cast("list[types.ToolCall]", tool_calls)
return [block for block in self.content if block["type"] == "tool_call"] return self._tool_calls
@tool_calls.setter @tool_calls.setter
def tool_calls(self, value: list[types.ToolCall]) -> None: def tool_calls(self, value: list[types.ToolCall]) -> None:
@ -351,9 +356,11 @@ class AIMessageChunk(AIMessage):
@property @property
def text(self) -> Optional[str]: def text(self) -> Optional[str]:
"""Extract all text content from the AI message as a string.""" """Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if block["type"] == "text"] text_blocks = [block for block in self.content if block.get("type") == "text"]
if text_blocks: if text_blocks:
return "".join(block["text"] for block in text_blocks) return "".join(
cast("types.TextContentBlock", block)["text"] for block in text_blocks
)
return None return None
@property @property
@ -365,7 +372,10 @@ class AIMessageChunk(AIMessage):
if block.get("type") == "reasoning" and "reasoning" in block if block.get("type") == "reasoning" and "reasoning" in block
] ]
if text_blocks: if text_blocks:
return "".join(block["reasoning"] for block in text_blocks) return "".join(
cast("types.ReasoningContentBlock", block).get("reasoning", "")
for block in text_blocks
)
return None return None
@property @property
@ -377,7 +387,7 @@ class AIMessageChunk(AIMessage):
block for block in self.content if block.get("type") == "tool_call" block for block in self.content if block.get("type") == "tool_call"
] ]
if tool_calls: if tool_calls:
self._tool_calls = tool_calls self._tool_calls = cast("list[types.ToolCall]", tool_calls)
return self._tool_calls return self._tool_calls
@tool_calls.setter @tool_calls.setter
@ -561,7 +571,9 @@ class HumanMessage:
Concatenated string of all text blocks in the message. Concatenated string of all text blocks in the message.
""" """
return "".join( return "".join(
block["text"] for block in self.content if block.get("type") == "text" cast("types.TextContentBlock", block)["text"]
for block in self.content
if block.get("type") == "text"
) )
@ -640,7 +652,9 @@ class SystemMessage:
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the system message.""" """Extract all text content from the system message."""
return "".join( return "".join(
block["text"] for block in self.content if block.get("type") == "text" cast("types.TextContentBlock", block)["text"]
for block in self.content
if block.get("type") == "text"
) )
@ -732,7 +746,9 @@ class ToolMessage:
def text(self) -> str: def text(self) -> str:
"""Extract all text content from the tool message.""" """Extract all text content from the tool message."""
return "".join( return "".join(
block["text"] for block in self.content if block.get("type") == "text" cast("types.TextContentBlock", block)["text"]
for block in self.content
if block.get("type") == "text"
) )
def __post_init__(self) -> None: def __post_init__(self) -> None: