type guards to remove casting

This commit is contained in:
Mason Daugherty 2025-08-04 11:29:36 -04:00
parent 822dd5075c
commit a6686d7c4f
No known key found for this signature in database

View File

@ -7,7 +7,7 @@ Each message has content that may be comprised of content blocks, defined under
import json
import uuid
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, Union, cast, get_args
from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args
from pydantic import BaseModel
from typing_extensions import TypedDict
@ -29,6 +29,23 @@ from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import parse_partial_json
def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]:
"""Type guard to check if a content block is a tool call."""
return block.get("type") == "tool_call"
def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]:
"""Type guard to check if a content block is a text block."""
return block.get("type") == "text"
def is_invalid_tool_call_block(
block: types.ContentBlock,
) -> TypeGuard[types.InvalidToolCall]:
"""Type guard to check if a content block is an invalid tool call."""
return block.get("type") == "invalid_tool_call"
def _ensure_id(id_val: Optional[str]) -> str:
"""Ensure the ID is a valid string, generating a new UUID if not provided.
@ -177,20 +194,17 @@ class AIMessage:
if "id" in tool_call and tool_call["id"] in content_tool_calls:
continue
self.content.append(tool_call)
self._tool_calls: list[types.ToolCall] = cast(
"list[types.ToolCall]",
[block for block in self.content if block.get("type") == "tool_call"],
)
self._tool_calls: list[types.ToolCall] = [
block for block in self.content if is_tool_call_block(block)
]
self.invalid_tool_calls = invalid_tool_calls or []
@property
def text(self) -> Optional[str]:
"""Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if block.get("type") == "text"]
text_blocks = [block for block in self.content if is_text_block(block)]
if text_blocks:
return "".join(
cast("types.TextContentBlock", block)["text"] for block in text_blocks
)
return "".join(block["text"] for block in text_blocks)
return None
@property
@ -198,11 +212,9 @@ class AIMessage:
"""Get the tool calls made by the AI."""
if self._tool_calls:
return self._tool_calls
tool_calls = [
block for block in self.content if block.get("type") == "tool_call"
]
tool_calls = [block for block in self.content if is_tool_call_block(block)]
if tool_calls:
self._tool_calls = cast("list[types.ToolCall]", tool_calls)
self._tool_calls = tool_calls
return self._tool_calls
@tool_calls.setter
@ -356,11 +368,9 @@ class AIMessageChunk(AIMessage):
@property
def text(self) -> Optional[str]:
"""Extract all text content from the AI message as a string."""
text_blocks = [block for block in self.content if block.get("type") == "text"]
text_blocks = [block for block in self.content if is_text_block(block)]
if text_blocks:
return "".join(
cast("types.TextContentBlock", block)["text"] for block in text_blocks
)
return "".join(block["text"] for block in text_blocks)
return None
@property
@ -383,11 +393,9 @@ class AIMessageChunk(AIMessage):
"""Get the tool calls made by the AI."""
if self._tool_calls:
return self._tool_calls
tool_calls = [
block for block in self.content if block.get("type") == "tool_call"
]
tool_calls = [block for block in self.content if is_tool_call_block(block)]
if tool_calls:
self._tool_calls = cast("list[types.ToolCall]", tool_calls)
self._tool_calls = tool_calls
return self._tool_calls
@tool_calls.setter