From ff3153c04d484c018c43cf18154278184c3e3c43 Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 4 Aug 2025 12:32:11 -0300 Subject: [PATCH] feat(core): move tool call chunks to content (v1) (#32358) --- .../language_models/v1/chat_models.py | 13 +- .../langchain_core/messages/content_blocks.py | 18 +- libs/core/langchain_core/messages/v1.py | 287 +++++++++--------- .../langchain_core/output_parsers/base.py | 2 +- .../langchain_core/output_parsers/json.py | 2 +- .../langchain_core/output_parsers/list.py | 4 +- .../language_models/chat_models/test_base.py | 96 +++++- .../output_parsers/test_base_parsers.py | 6 +- .../runnables/__snapshots__/test_graph.ambr | 15 + .../__snapshots__/test_runnable.ambr | 15 + libs/core/tests/unit_tests/test_messages.py | 36 +-- 11 files changed, 321 insertions(+), 173 deletions(-) diff --git a/libs/core/langchain_core/language_models/v1/chat_models.py b/libs/core/langchain_core/language_models/v1/chat_models.py index 775d58dd7a0..f8d83cb4bbe 100644 --- a/libs/core/langchain_core/language_models/v1/chat_models.py +++ b/libs/core/langchain_core/language_models/v1/chat_models.py @@ -458,7 +458,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC chunks: list[AIMessageChunkV1] = [] try: for msg in self._stream(input_messages, **kwargs): - run_manager.on_llm_new_token(msg.text or "") + run_manager.on_llm_new_token(msg.text) chunks.append(msg) except BaseException as e: run_manager.on_llm_error(e, response=_generate_response_from_error(e)) @@ -525,7 +525,7 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC chunks: list[AIMessageChunkV1] = [] try: async for msg in self._astream(input_messages, **kwargs): - await run_manager.on_llm_new_token(msg.text or "") + await run_manager.on_llm_new_token(msg.text) chunks.append(msg) except BaseException as e: await run_manager.on_llm_error( @@ -602,9 +602,12 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC # TODO: replace this with something for new messages input_messages = _normalize_messages_v1(messages) for msg in self._stream(input_messages, **kwargs): - run_manager.on_llm_new_token(msg.text or "") + run_manager.on_llm_new_token(msg.text) chunks.append(msg) yield msg + + if msg.chunk_position != "last": + yield (AIMessageChunkV1([], chunk_position="last")) except BaseException as e: run_manager.on_llm_error(e, response=_generate_response_from_error(e)) raise @@ -673,9 +676,11 @@ class BaseChatModelV1(RunnableSerializable[LanguageModelInput, AIMessageV1], ABC input_messages, **kwargs, ): - await run_manager.on_llm_new_token(msg.text or "") + await run_manager.on_llm_new_token(msg.text) chunks.append(msg) yield msg + if msg.chunk_position != "last": + yield (AIMessageChunkV1([], chunk_position="last")) except BaseException as e: await run_manager.on_llm_error(e, response=_generate_response_from_error(e)) raise diff --git a/libs/core/langchain_core/messages/content_blocks.py b/libs/core/langchain_core/messages/content_blocks.py index fedf4d92043..025f16eeb37 100644 --- a/libs/core/langchain_core/messages/content_blocks.py +++ b/libs/core/langchain_core/messages/content_blocks.py @@ -843,6 +843,7 @@ ContentBlock = Union[ TextContentBlock, ToolCall, InvalidToolCall, + ToolCallChunk, ReasoningContentBlock, NonStandardContentBlock, DataContentBlock, @@ -864,7 +865,22 @@ def _extract_typedict_type_values(union_type: Any) -> set[str]: KNOWN_BLOCK_TYPES = { - bt for bt in get_args(ContentBlock) for bt in get_args(bt.__annotations__["type"]) + "text", + "text-plain", + "tool_call", + "invalid_tool_call", + "tool_call_chunk", + "reasoning", + "non_standard", + "image", + "audio", + "file", + "video", + "code_interpreter_call", + "code_interpreter_output", + "code_interpreter_result", + "web_search_call", + "web_search_result", } diff --git a/libs/core/langchain_core/messages/v1.py b/libs/core/langchain_core/messages/v1.py index 79ebbf4a72c..1b779e4afd4 100644 --- a/libs/core/langchain_core/messages/v1.py +++ b/libs/core/langchain_core/messages/v1.py @@ -4,7 +4,6 @@ Each message has content that may be comprised of content blocks, defined under ``langchain_core.messages.content_blocks``. """ -import json import uuid from dataclasses import dataclass, field from typing import Any, Literal, Optional, Union, cast, get_args @@ -20,11 +19,9 @@ from langchain_core.messages.ai import ( add_usage, ) from langchain_core.messages.base import merge_content -from langchain_core.messages.tool import ToolCallChunk from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call -from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk -from langchain_core.utils._merge import merge_dicts, merge_lists +from langchain_core.utils._merge import merge_dicts from langchain_core.utils.json import parse_partial_json @@ -176,34 +173,55 @@ class AIMessage: if "id" in tool_call and tool_call["id"] in content_tool_calls: continue self.content.append(tool_call) + if invalid_tool_calls: + content_tool_calls = { + block["id"] + for block in self.content + if block["type"] == "invalid_tool_call" and "id" in block + } + for invalid_tool_call in invalid_tool_calls: + if ( + "id" in invalid_tool_call + and invalid_tool_call["id"] in content_tool_calls + ): + continue + self.content.append(invalid_tool_call) self._tool_calls = [ block for block in self.content if block["type"] == "tool_call" ] - self.invalid_tool_calls = invalid_tool_calls or [] + self._invalid_tool_calls = [ + block for block in self.content if block["type"] == "invalid_tool_call" + ] @property - def text(self) -> Optional[str]: + def text(self) -> str: """Extract all text content from the AI message as a string.""" text_blocks = [block for block in self.content if block["type"] == "text"] - if text_blocks: - return "".join(block["text"] for block in text_blocks) - return None + return "".join(block["text"] for block in text_blocks) @property - def tool_calls(self) -> list[types.ToolCall]: # update once we fix branch + def tool_calls(self) -> list[types.ToolCall]: """Get the tool calls made by the AI.""" - if self._tool_calls: - return self._tool_calls - tool_calls = [block for block in self.content if block["type"] == "tool_call"] - if tool_calls: - self._tool_calls = tool_calls - return [block for block in self.content if block["type"] == "tool_call"] + if not self._tool_calls: + self._tool_calls = [ + block for block in self.content if block["type"] == "tool_call" + ] + return self._tool_calls @tool_calls.setter def tool_calls(self, value: list[types.ToolCall]) -> None: """Set the tool calls for the AI message.""" self._tool_calls = value + @property + def invalid_tool_calls(self) -> list[types.InvalidToolCall]: + """Get the invalid tool calls made by the AI.""" + if not self._invalid_tool_calls: + self._invalid_tool_calls = [ + block for block in self.content if block["type"] == "invalid_tool_call" + ] + return self._invalid_tool_calls + @dataclass class AIMessageChunk(AIMessage): @@ -228,17 +246,10 @@ class AIMessageChunk(AIMessage): when deserializing messages. """ - tool_call_chunks: list[types.ToolCallChunk] = field(init=False) - """List of partial tool call data. - - Emitted by the model during streaming, this field contains - tool call chunks that may not yet be complete. It is used to reconstruct - tool calls from the streamed content. - """ - def __init__( self, content: Union[str, list[types.ContentBlock]], + *, id: Optional[str] = None, name: Optional[str] = None, lc_version: str = "v1", @@ -246,6 +257,7 @@ class AIMessageChunk(AIMessage): usage_metadata: Optional[UsageMetadata] = None, tool_call_chunks: Optional[list[types.ToolCallChunk]] = None, parsed: Optional[Union[dict[str, Any], BaseModel]] = None, + chunk_position: Optional[Literal["last"]] = None, ): """Initialize an AI message. @@ -258,6 +270,8 @@ class AIMessageChunk(AIMessage): usage_metadata: Optional metadata about token usage. tool_call_chunks: Optional list of partial tool call data. parsed: Optional auto-parsed message contents, if applicable. + chunk_position: Optional position of the chunk in the stream. If "last", + tool calls will be parsed when aggregated into a stream. """ if isinstance(content, str): self.content = [{"type": "text", "text": content, "index": 0}] @@ -269,112 +283,53 @@ class AIMessageChunk(AIMessage): self.lc_version = lc_version self.usage_metadata = usage_metadata self.parsed = parsed + self.chunk_position = chunk_position if response_metadata is None: self.response_metadata = {} else: self.response_metadata = response_metadata - if tool_call_chunks is None: - self.tool_call_chunks: list[types.ToolCallChunk] = [] - else: - self.tool_call_chunks = tool_call_chunks + + if tool_call_chunks: + content_tool_call_chunks = { + block["id"] + for block in self.content + if block.get("type") == "tool_call_chunk" and "id" in block + } + for chunk in tool_call_chunks: + if "id" in chunk and chunk["id"] in content_tool_call_chunks: + continue + self.content.append(chunk) + self._tool_call_chunks = [ + block for block in self.content if block.get("type") == "tool_call_chunk" + ] self._tool_calls: list[types.ToolCall] = [] - self.invalid_tool_calls: list[types.InvalidToolCall] = [] - self._init_tool_calls() - - def _init_tool_calls(self) -> None: - """Initialize tool calls from tool call chunks. - - Args: - values: The values to validate. - - Raises: - ValueError: If the tool call chunks are malformed. - """ - self._tool_calls = [] - self.invalid_tool_calls = [] - if not self.tool_call_chunks: - if self._tool_calls: - self.tool_call_chunks = [ - create_tool_call_chunk( - name=tc["name"], - args=json.dumps(tc["args"]), - id=tc["id"], - index=None, - ) - for tc in self._tool_calls - ] - if self.invalid_tool_calls: - tool_call_chunks = self.tool_call_chunks - tool_call_chunks.extend( - [ - create_tool_call_chunk( - name=tc["name"], args=tc["args"], id=tc["id"], index=None - ) - for tc in self.invalid_tool_calls - ] - ) - self.tool_call_chunks = tool_call_chunks - - tool_calls = [] - invalid_tool_calls = [] - - def add_chunk_to_invalid_tool_calls(chunk: ToolCallChunk) -> None: - invalid_tool_calls.append( - create_invalid_tool_call( - name=chunk.get("name", ""), - args=chunk.get("args", ""), - id=chunk.get("id", ""), - error=None, - ) - ) - - for chunk in self.tool_call_chunks: - try: - args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {} # type: ignore[arg-type] - if isinstance(args_, dict): - tool_calls.append( - create_tool_call( - name=chunk.get("name") or "", - args=args_, - id=chunk.get("id", ""), - ) - ) - else: - add_chunk_to_invalid_tool_calls(chunk) - except Exception: - add_chunk_to_invalid_tool_calls(chunk) - self._tool_calls = tool_calls - self.invalid_tool_calls = invalid_tool_calls + self._invalid_tool_calls: list[types.InvalidToolCall] = [] @property - def text(self) -> Optional[str]: - """Extract all text content from the AI message as a string.""" - text_blocks = [block for block in self.content if block["type"] == "text"] - if text_blocks: - return "".join(block["text"] for block in text_blocks) - return None - - @property - def reasoning(self) -> Optional[str]: - """Extract all reasoning text from the AI message as a string.""" - text_blocks = [ - block - for block in self.content - if block["type"] == "reasoning" and "reasoning" in block - ] - if text_blocks: - return "".join(block["reasoning"] for block in text_blocks) - return None + def tool_call_chunks(self) -> list[types.ToolCallChunk]: + """Get the tool calls made by the AI.""" + if not self._tool_call_chunks: + self._tool_call_chunks = [ + block + for block in self.content + if block.get("type") == "tool_call_chunk" + ] + return cast("list[types.ToolCallChunk]", self._tool_call_chunks) @property def tool_calls(self) -> list[types.ToolCall]: """Get the tool calls made by the AI.""" - if self._tool_calls: - return self._tool_calls - tool_calls = [block for block in self.content if block["type"] == "tool_call"] - if tool_calls: - self._tool_calls = tool_calls + if not self._tool_calls: + parsed_content = _init_tool_calls(self.content) + self._tool_calls = [ + block for block in parsed_content if block["type"] == "tool_call" + ] + self._invalid_tool_calls = [ + block + for block in parsed_content + if block["type"] == "invalid_tool_call" + ] return self._tool_calls @tool_calls.setter @@ -382,6 +337,21 @@ class AIMessageChunk(AIMessage): """Set the tool calls for the AI message.""" self._tool_calls = value + @property + def invalid_tool_calls(self) -> list[types.InvalidToolCall]: + """Get the invalid tool calls made by the AI.""" + if not self._invalid_tool_calls: + parsed_content = _init_tool_calls(self.content) + self._tool_calls = [ + block for block in parsed_content if block["type"] == "tool_call" + ] + self._invalid_tool_calls = [ + block + for block in parsed_content + if block["type"] == "invalid_tool_call" + ] + return self._invalid_tool_calls + def __add__(self, other: Any) -> "AIMessageChunk": """Add AIMessageChunk to this one.""" if isinstance(other, AIMessageChunk): @@ -396,49 +366,76 @@ class AIMessageChunk(AIMessage): def to_message(self) -> "AIMessage": """Convert this AIMessageChunk to an AIMessage.""" return AIMessage( - content=self.content, + content=_init_tool_calls(self.content), id=self.id, name=self.name, lc_version=self.lc_version, response_metadata=self.response_metadata, usage_metadata=self.usage_metadata, - tool_calls=self.tool_calls, - invalid_tool_calls=self.invalid_tool_calls, parsed=self.parsed, ) +def _init_tool_calls(content: list[types.ContentBlock]) -> list[types.ContentBlock]: + """Parse tool call chunks in content into tool calls.""" + new_content = [] + for block in content: + if block.get("type") != "tool_call_chunk": + new_content.append(block) + continue + try: + args_ = ( + parse_partial_json(cast("str", block.get("args") or "")) + if block.get("args") + else {} + ) + if isinstance(args_, dict): + new_content.append( + create_tool_call( + name=cast("str", block.get("name") or ""), + args=args_, + id=cast("str", block.get("id", "")), + ) + ) + else: + new_content.append( + create_invalid_tool_call( + name=cast("str", block.get("name", "")), + args=cast("str", block.get("args", "")), + id=cast("str", block.get("id", "")), + error=None, + ) + ) + except Exception: + new_content.append( + create_invalid_tool_call( + name=cast("str", block.get("name", "")), + args=cast("str", block.get("args", "")), + id=cast("str", block.get("id", "")), + error=None, + ) + ) + return new_content + + def add_ai_message_chunks( left: AIMessageChunk, *others: AIMessageChunk ) -> AIMessageChunk: """Add multiple AIMessageChunks together.""" if not others: return left - content = merge_content( - cast("list[str | dict[Any, Any]]", left.content), - *(cast("list[str | dict[Any, Any]]", o.content) for o in others), + content = cast( + "list[types.ContentBlock]", + merge_content( + cast("list[str | dict[Any, Any]]", left.content), + *(cast("list[str | dict[Any, Any]]", o.content) for o in others), + ), ) response_metadata = merge_dicts( cast("dict", left.response_metadata), *(cast("dict", o.response_metadata) for o in others), ) - # Merge tool call chunks - if raw_tool_calls := merge_lists( - left.tool_call_chunks, *(o.tool_call_chunks for o in others) - ): - tool_call_chunks = [ - create_tool_call_chunk( - name=rtc.get("name"), - args=rtc.get("args"), - index=rtc.get("index"), - id=rtc.get("id"), - ) - for rtc in raw_tool_calls - ] - else: - tool_call_chunks = [] - # Token usage if left.usage_metadata or any(o.usage_metadata is not None for o in others): usage_metadata: Optional[UsageMetadata] = left.usage_metadata @@ -480,13 +477,19 @@ def add_ai_message_chunks( chunk_id = id_ break + chunk_position: Optional[Literal["last"]] = ( + "last" if any(x.chunk_position == "last" for x in [left, *others]) else None + ) + if chunk_position == "last": + content = _init_tool_calls(content) + return left.__class__( - content=cast("list[types.ContentBlock]", content), - tool_call_chunks=tool_call_chunks, + content=content, response_metadata=cast("ResponseMetadata", response_metadata), usage_metadata=usage_metadata, parsed=parsed, id=chunk_id, + chunk_position=chunk_position, ) diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 6ce45f94590..b91ac475842 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -283,7 +283,7 @@ class BaseOutputParser( Structured output. """ if isinstance(result, AIMessage): - return self.parse(result.text or "") + return self.parse(result.text) return self.parse(result[0].text) @abstractmethod diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index c9465d690a1..1384774afff 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -73,7 +73,7 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): Raises: OutputParserException: If the output is not valid JSON. """ - text = result.text or "" if isinstance(result, AIMessage) else result[0].text + text = result.text if isinstance(result, AIMessage) else result[0].text text = text.strip() if partial: try: diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index a6ca2511c85..ce825ee2b7d 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -83,7 +83,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): continue buffer += chunk_content elif isinstance(chunk, AIMessage): - buffer += chunk.text or "" + buffer += chunk.text else: # add current chunk to buffer buffer += chunk @@ -119,7 +119,7 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]): continue buffer += chunk_content elif isinstance(chunk, AIMessage): - buffer += chunk.text or "" + buffer += chunk.text else: # add current chunk to buffer buffer += chunk diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 37b05ed8255..e653d35fda8 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -14,7 +14,10 @@ from langchain_core.language_models import ( ParrotFakeChatModel, ) from langchain_core.language_models._utils import _normalize_messages -from langchain_core.language_models.fake_chat_models import FakeListChatModelError +from langchain_core.language_models.fake_chat_models import ( + FakeListChatModelError, + GenericFakeChatModelV1, +) from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -22,6 +25,7 @@ from langchain_core.messages import ( HumanMessage, SystemMessage, ) +from langchain_core.messages.v1 import AIMessageChunk as AIMessageChunkV1 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs.llm_result import LLMResult from langchain_core.tracers import LogStreamCallbackHandler @@ -654,3 +658,93 @@ def test_normalize_messages_edge_cases() -> None: ) ] assert messages == _normalize_messages(messages) + + +def test_streaming_v1() -> None: + chunks = [ + AIMessageChunkV1( + [ + { + "type": "reasoning", + "reasoning": "Let's call a tool.", + "index": 0, + } + ] + ), + AIMessageChunkV1( + [], + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "args": "", + "name": "tool_name", + "id": "call_123", + "index": 1, + }, + ], + ), + AIMessageChunkV1( + [], + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "args": '{"a', + "name": "", + "id": "", + "index": 1, + }, + ], + ), + AIMessageChunkV1( + [], + tool_call_chunks=[ + { + "type": "tool_call_chunk", + "args": '": 1}', + "name": "", + "id": "", + "index": 1, + }, + ], + ), + ] + full: Optional[AIMessageChunkV1] = None + for chunk in chunks: + full = chunk if full is None else full + chunk + + assert isinstance(full, AIMessageChunkV1) + assert full.content == [ + { + "type": "reasoning", + "reasoning": "Let's call a tool.", + "index": 0, + }, + { + "type": "tool_call_chunk", + "args": '{"a": 1}', + "name": "tool_name", + "id": "call_123", + "index": 1, + }, + ] + + llm = GenericFakeChatModelV1(message_chunks=chunks) + + full = None + for chunk in llm.stream("anything"): + full = chunk if full is None else full + chunk + + assert isinstance(full, AIMessageChunkV1) + assert full.content == [ + { + "type": "reasoning", + "reasoning": "Let's call a tool.", + "index": 0, + }, + { + "type": "tool_call", + "args": {"a": 1}, + "name": "tool_name", + "id": "call_123", + }, + ] diff --git a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py index fc687629f01..8794d3d2f65 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py +++ b/libs/core/tests/unit_tests/output_parsers/test_base_parsers.py @@ -37,7 +37,7 @@ def test_base_generation_parser() -> None: that support streaming """ if isinstance(result, AIMessageV1): - content = result.text or "" + content = result.text else: if len(result) != 1: msg = ( @@ -89,7 +89,7 @@ def test_base_transform_output_parser() -> None: that support streaming """ if isinstance(result, AIMessageV1): - content = result.text or "" + content = result.text else: if len(result) != 1: msg = ( @@ -116,4 +116,4 @@ def test_base_transform_output_parser() -> None: model_v1 = GenericFakeChatModelV1(message_chunks=["hello", " ", "world"]) chain_v1 = model_v1 | StrInvertCase() chunks = list(chain_v1.stream("")) - assert chunks == ["HELLO", " ", "WORLD"] + assert chunks == ["HELLO", " ", "WORLD", ""] diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index c600f508299..acce9007092 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -2543,6 +2543,9 @@ dict({ '$ref': '#/$defs/InvalidToolCall', }), + dict({ + '$ref': '#/$defs/ToolCallChunk', + }), dict({ '$ref': '#/$defs/ReasoningContentBlock', }), @@ -2666,6 +2669,9 @@ dict({ '$ref': '#/$defs/InvalidToolCall', }), + dict({ + '$ref': '#/$defs/ToolCallChunk', + }), dict({ '$ref': '#/$defs/ReasoningContentBlock', }), @@ -2789,6 +2795,9 @@ dict({ '$ref': '#/$defs/InvalidToolCall', }), + dict({ + '$ref': '#/$defs/ToolCallChunk', + }), dict({ '$ref': '#/$defs/ReasoningContentBlock', }), @@ -2874,6 +2883,9 @@ dict({ '$ref': '#/$defs/InvalidToolCall', }), + dict({ + '$ref': '#/$defs/ToolCallChunk', + }), dict({ '$ref': '#/$defs/ReasoningContentBlock', }), @@ -2982,6 +2994,9 @@ dict({ '$ref': '#/$defs/InvalidToolCall', }), + dict({ + '$ref': '#/$defs/ToolCallChunk', + }), dict({ '$ref': '#/$defs/ReasoningContentBlock', }), diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index ea62b19f550..fdc3ca0c813 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -11535,6 +11535,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11657,6 +11660,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11779,6 +11785,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11863,6 +11872,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), @@ -11970,6 +11982,9 @@ dict({ '$ref': '#/definitions/InvalidToolCall', }), + dict({ + '$ref': '#/definitions/ToolCallChunk', + }), dict({ '$ref': '#/definitions/ReasoningContentBlock', }), diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 77ddf96c974..5da1ad20f5d 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -3,6 +3,7 @@ import uuid from typing import Optional, Union import pytest +from typing_extensions import get_args from langchain_core.documents import Document from langchain_core.load import dumpd, load @@ -30,7 +31,7 @@ from langchain_core.messages import ( messages_from_dict, messages_to_dict, ) -from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES +from langchain_core.messages.content_blocks import KNOWN_BLOCK_TYPES, ContentBlock from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call from langchain_core.messages.tool import tool_call as create_tool_call from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk @@ -1363,20 +1364,19 @@ def test_convert_to_openai_image_block() -> None: def test_known_block_types() -> None: - assert { - "text", - "text-plain", - "tool_call", - "invalid_tool_call", - "reasoning", - "non_standard", - "image", - "audio", - "file", - "video", - "code_interpreter_call", - "code_interpreter_output", - "code_interpreter_result", - "web_search_call", - "web_search_result", - } == KNOWN_BLOCK_TYPES + expected = { + bt + for bt in get_args(ContentBlock) + for bt in get_args(bt.__annotations__["type"]) + } + # Normalize any Literal[...] types in block types to their string values. + # This ensures all entries are plain strings, not Literal objects. + expected = { + t + if isinstance(t, str) + else t.__args__[0] + if hasattr(t, "__args__") and len(t.__args__) == 1 + else t + for t in expected + } + assert expected == KNOWN_BLOCK_TYPES