diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py index e68db4b6a47..66063813846 100644 --- a/libs/core/langchain_core/messages/v1.py +++ b/libs/core/langchain_core/messages/v1.py @@ -7,7 +7,7 @@ Each message has content that may be comprised of content blocks, defined under import json import uuid from dataclasses import dataclass, field -from typing import Any, Literal, Optional, Union, cast, get_args +from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args from pydantic import BaseModel from typing_extensions import TypedDict @@ -29,6 +29,23 @@ from langchain_core.utils._merge import merge_dicts, merge_lists from langchain_core.utils.json import parse_partial_json +def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]: + """Type guard to check if a content block is a tool call.""" + return block.get("type") == "tool_call" + + +def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]: + """Type guard to check if a content block is a text block.""" + return block.get("type") == "text" + + +def is_invalid_tool_call_block( + block: types.ContentBlock, +) -> TypeGuard[types.InvalidToolCall]: + """Type guard to check if a content block is an invalid tool call.""" + return block.get("type") == "invalid_tool_call" + + def _ensure_id(id_val: Optional[str]) -> str: """Ensure the ID is a valid string, generating a new UUID if not provided. @@ -177,20 +194,17 @@ class AIMessage: if "id" in tool_call and tool_call["id"] in content_tool_calls: continue self.content.append(tool_call) - self._tool_calls: list[types.ToolCall] = cast( - "list[types.ToolCall]", - [block for block in self.content if block.get("type") == "tool_call"], - ) + self._tool_calls: list[types.ToolCall] = [ + block for block in self.content if is_tool_call_block(block) + ] self.invalid_tool_calls = invalid_tool_calls or [] @property def text(self) -> Optional[str]: """Extract all text content from the AI message as a string.""" - text_blocks = [block for block in self.content if block.get("type") == "text"] + text_blocks = [block for block in self.content if is_text_block(block)] if text_blocks: - return "".join( - cast("types.TextContentBlock", block)["text"] for block in text_blocks - ) + return "".join(block["text"] for block in text_blocks) return None @property @@ -198,11 +212,9 @@ class AIMessage: """Get the tool calls made by the AI.""" if self._tool_calls: return self._tool_calls - tool_calls = [ - block for block in self.content if block.get("type") == "tool_call" - ] + tool_calls = [block for block in self.content if is_tool_call_block(block)] if tool_calls: - self._tool_calls = cast("list[types.ToolCall]", tool_calls) + self._tool_calls = tool_calls return self._tool_calls @tool_calls.setter @@ -356,11 +368,9 @@ class AIMessageChunk(AIMessage): @property def text(self) -> Optional[str]: """Extract all text content from the AI message as a string.""" - text_blocks = [block for block in self.content if block.get("type") == "text"] + text_blocks = [block for block in self.content if is_text_block(block)] if text_blocks: - return "".join( - cast("types.TextContentBlock", block)["text"] for block in text_blocks - ) + return "".join(block["text"] for block in text_blocks) return None @property @@ -383,11 +393,9 @@ class AIMessageChunk(AIMessage): """Get the tool calls made by the AI.""" if self._tool_calls: return self._tool_calls - tool_calls = [ - block for block in self.content if block.get("type") == "tool_call" - ] + tool_calls = [block for block in self.content if is_tool_call_block(block)] if tool_calls: - self._tool_calls = cast("list[types.ToolCall]", tool_calls) + self._tool_calls = tool_calls return self._tool_calls @tool_calls.setter