mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
type guards to remove casting
This commit is contained in:
parent
822dd5075c
commit
a6686d7c4f
@ -7,7 +7,7 @@ Each message has content that may be comprised of content blocks, defined under
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Literal, Optional, Union, cast, get_args
|
from typing import Any, Literal, Optional, TypeGuard, Union, cast, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
@ -29,6 +29,23 @@ from langchain_core.utils._merge import merge_dicts, merge_lists
|
|||||||
from langchain_core.utils.json import parse_partial_json
|
from langchain_core.utils.json import parse_partial_json
|
||||||
|
|
||||||
|
|
||||||
|
def is_tool_call_block(block: types.ContentBlock) -> TypeGuard[types.ToolCall]:
|
||||||
|
"""Type guard to check if a content block is a tool call."""
|
||||||
|
return block.get("type") == "tool_call"
|
||||||
|
|
||||||
|
|
||||||
|
def is_text_block(block: types.ContentBlock) -> TypeGuard[types.TextContentBlock]:
|
||||||
|
"""Type guard to check if a content block is a text block."""
|
||||||
|
return block.get("type") == "text"
|
||||||
|
|
||||||
|
|
||||||
|
def is_invalid_tool_call_block(
|
||||||
|
block: types.ContentBlock,
|
||||||
|
) -> TypeGuard[types.InvalidToolCall]:
|
||||||
|
"""Type guard to check if a content block is an invalid tool call."""
|
||||||
|
return block.get("type") == "invalid_tool_call"
|
||||||
|
|
||||||
|
|
||||||
def _ensure_id(id_val: Optional[str]) -> str:
|
def _ensure_id(id_val: Optional[str]) -> str:
|
||||||
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
"""Ensure the ID is a valid string, generating a new UUID if not provided.
|
||||||
|
|
||||||
@ -177,20 +194,17 @@ class AIMessage:
|
|||||||
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
if "id" in tool_call and tool_call["id"] in content_tool_calls:
|
||||||
continue
|
continue
|
||||||
self.content.append(tool_call)
|
self.content.append(tool_call)
|
||||||
self._tool_calls: list[types.ToolCall] = cast(
|
self._tool_calls: list[types.ToolCall] = [
|
||||||
"list[types.ToolCall]",
|
block for block in self.content if is_tool_call_block(block)
|
||||||
[block for block in self.content if block.get("type") == "tool_call"],
|
]
|
||||||
)
|
|
||||||
self.invalid_tool_calls = invalid_tool_calls or []
|
self.invalid_tool_calls = invalid_tool_calls or []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> Optional[str]:
|
def text(self) -> Optional[str]:
|
||||||
"""Extract all text content from the AI message as a string."""
|
"""Extract all text content from the AI message as a string."""
|
||||||
text_blocks = [block for block in self.content if block.get("type") == "text"]
|
text_blocks = [block for block in self.content if is_text_block(block)]
|
||||||
if text_blocks:
|
if text_blocks:
|
||||||
return "".join(
|
return "".join(block["text"] for block in text_blocks)
|
||||||
cast("types.TextContentBlock", block)["text"] for block in text_blocks
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -198,11 +212,9 @@ class AIMessage:
|
|||||||
"""Get the tool calls made by the AI."""
|
"""Get the tool calls made by the AI."""
|
||||||
if self._tool_calls:
|
if self._tool_calls:
|
||||||
return self._tool_calls
|
return self._tool_calls
|
||||||
tool_calls = [
|
tool_calls = [block for block in self.content if is_tool_call_block(block)]
|
||||||
block for block in self.content if block.get("type") == "tool_call"
|
|
||||||
]
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
self._tool_calls = cast("list[types.ToolCall]", tool_calls)
|
self._tool_calls = tool_calls
|
||||||
return self._tool_calls
|
return self._tool_calls
|
||||||
|
|
||||||
@tool_calls.setter
|
@tool_calls.setter
|
||||||
@ -356,11 +368,9 @@ class AIMessageChunk(AIMessage):
|
|||||||
@property
|
@property
|
||||||
def text(self) -> Optional[str]:
|
def text(self) -> Optional[str]:
|
||||||
"""Extract all text content from the AI message as a string."""
|
"""Extract all text content from the AI message as a string."""
|
||||||
text_blocks = [block for block in self.content if block.get("type") == "text"]
|
text_blocks = [block for block in self.content if is_text_block(block)]
|
||||||
if text_blocks:
|
if text_blocks:
|
||||||
return "".join(
|
return "".join(block["text"] for block in text_blocks)
|
||||||
cast("types.TextContentBlock", block)["text"] for block in text_blocks
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -383,11 +393,9 @@ class AIMessageChunk(AIMessage):
|
|||||||
"""Get the tool calls made by the AI."""
|
"""Get the tool calls made by the AI."""
|
||||||
if self._tool_calls:
|
if self._tool_calls:
|
||||||
return self._tool_calls
|
return self._tool_calls
|
||||||
tool_calls = [
|
tool_calls = [block for block in self.content if is_tool_call_block(block)]
|
||||||
block for block in self.content if block.get("type") == "tool_call"
|
|
||||||
]
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
self._tool_calls = cast("list[types.ToolCall]", tool_calls)
|
self._tool_calls = tool_calls
|
||||||
return self._tool_calls
|
return self._tool_calls
|
||||||
|
|
||||||
@tool_calls.setter
|
@tool_calls.setter
|
||||||
|
Loading…
Reference in New Issue
Block a user