mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 15:16:21 +00:00
type guards to remove casting
This commit is contained in:
parent
822dd5075c
commit
a6686d7c4f
@ -7,7 +7,7 @@ Each message has content that may be comprised of content blocks, defined under
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional, Union, cast, get_args
|
||||
from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
@ -29,6 +29,23 @@ from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
|
||||
|
||||
def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]:
|
||||
"""Type guard to check if a content block is a tool call."""
|
||||
return block.get("type") == "tool_call"
|
||||
|
||||
|
||||
def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]:
|
||||
"""Type guard to check if a content block is a text block."""
|
||||
return block.get("type") == "text"
|
||||
|
||||
|
||||
def is_invalid_tool_call_block(
|
||||
block: types.ContentBlock,
|
||||
) -> TypeGuard[types.InvalidToolCall]:
|
||||
"""Type guard to check if a content block is an invalid tool call."""
|
||||
return block.get("type") == "invalid_tool_call"
|
||||
|
||||
|
||||
def _ensure_id(id_val: Optional[str]) -> str:
|
||||
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
||||
|
||||
@ -177,20 +194,17 @@ class AIMessage:
|
||||
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
||||
continue
|
||||
self.content.append(tool_call)
|
||||
self._tool_calls: list[types.ToolCall] = cast(
|
||||
"list[types.ToolCall]",
|
||||
[block for block in self.content if block.get("type") == "tool_call"],
|
||||
)
|
||||
self._tool_calls: list[types.ToolCall] = [
|
||||
block for block in self.content if is_tool_call_block(block)
|
||||
]
|
||||
self.invalid_tool_calls = invalid_tool_calls or []
|
||||
|
||||
@property
|
||||
def text(self) -> Optional[str]:
|
||||
"""Extract all text content from the AI message as a string."""
|
||||
text_blocks = [block for block in self.content if block.get("type") == "text"]
|
||||
text_blocks = [block for block in self.content if is_text_block(block)]
|
||||
if text_blocks:
|
||||
return "".join(
|
||||
cast("types.TextContentBlock", block)["text"] for block in text_blocks
|
||||
)
|
||||
return "".join(block["text"] for block in text_blocks)
|
||||
return None
|
||||
|
||||
@property
|
||||
@ -198,11 +212,9 @@ class AIMessage:
|
||||
"""Get the tool calls made by the AI."""
|
||||
if self._tool_calls:
|
||||
return self._tool_calls
|
||||
tool_calls = [
|
||||
block for block in self.content if block.get("type") == "tool_call"
|
||||
]
|
||||
tool_calls = [block for block in self.content if is_tool_call_block(block)]
|
||||
if tool_calls:
|
||||
self._tool_calls = cast("list[types.ToolCall]", tool_calls)
|
||||
self._tool_calls = tool_calls
|
||||
return self._tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
@ -356,11 +368,9 @@ class AIMessageChunk(AIMessage):
|
||||
@property
|
||||
def text(self) -> Optional[str]:
|
||||
"""Extract all text content from the AI message as a string."""
|
||||
text_blocks = [block for block in self.content if block.get("type") == "text"]
|
||||
text_blocks = [block for block in self.content if is_text_block(block)]
|
||||
if text_blocks:
|
||||
return "".join(
|
||||
cast("types.TextContentBlock", block)["text"] for block in text_blocks
|
||||
)
|
||||
return "".join(block["text"] for block in text_blocks)
|
||||
return None
|
||||
|
||||
@property
|
||||
@ -383,11 +393,9 @@ class AIMessageChunk(AIMessage):
|
||||
"""Get the tool calls made by the AI."""
|
||||
if self._tool_calls:
|
||||
return self._tool_calls
|
||||
tool_calls = [
|
||||
block for block in self.content if block.get("type") == "tool_call"
|
||||
]
|
||||
tool_calls = [block for block in self.content if is_tool_call_block(block)]
|
||||
if tool_calls:
|
||||
self._tool_calls = cast("list[types.ToolCall]", tool_calls)
|
||||
self._tool_calls = tool_calls
|
||||
return self._tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
|
Loading…
Reference in New Issue
Block a user