fix(core): support interleaved content blocks in stream_v2 bridge

This commit is contained in:
Christian Bromann
2026-04-26 23:26:58 -07:00
parent 87ba30f097
commit 2e25005e9a
3 changed files with 175 additions and 58 deletions

View File

@@ -497,12 +497,12 @@ def chunks_to_events(
) -> Iterator[MessagesData]:
"""Convert a stream of `ChatGenerationChunk` to protocol events.
Blocks stream one at a time: when a chunk carries a different block
identifier than the currently-open one, the open block is finished
before the new block starts, matching the protocol's no-interleave
rule. Source-side identifiers (from the block's `index` field, which
may be int or string) are translated to sequential `uint` wire
indices.
Blocks are tracked independently by source-side identifier. Providers
such as Anthropic can interleave parallel tool-call chunks by index, so
each first-seen block gets a `content-block-start`, deltas keep their
stable wire index, and all open blocks are finalized at message end.
Source-side identifiers (from the block's `index` field, which may be
int or string) are translated to sequential `uint` wire indices.
Args:
chunks: Iterator of `ChatGenerationChunk` from `_stream()`.
@@ -512,9 +512,7 @@ def chunks_to_events(
`MessagesData` lifecycle events.
"""
started = False
open_key: Any = None
open_block: CompatBlock | None = None
open_wire_idx: int = 0
blocks: dict[Any, tuple[int, CompatBlock]] = {}
next_wire_idx = 0
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
@@ -545,24 +543,22 @@ def chunks_to_events(
yield _build_message_start(msg, message_id)
for key, block in _iter_protocol_blocks(msg):
if key != open_key:
if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
open_key = key
open_wire_idx = next_wire_idx
if key not in blocks:
wire_idx = next_wire_idx
next_wire_idx += 1
open_block = dict(block)
blocks[key] = (wire_idx, dict(block))
yield ContentBlockStartData(
event="content-block-start",
index=open_wire_idx,
index=wire_idx,
content_block=_start_skeleton(block),
)
else:
open_block = _accumulate(open_block, block)
wire_idx, existing = blocks[key]
blocks[key] = (wire_idx, _accumulate(existing, block))
if _should_emit_delta(block):
yield ContentBlockDeltaData(
event="content-block-delta",
index=open_wire_idx,
index=wire_idx,
content_block=_to_protocol_block(block),
)
@@ -572,8 +568,8 @@ def chunks_to_events(
if not started:
return
if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
for wire_idx, block in blocks.values():
yield _finalize_and_build_finish(wire_idx, block)
yield _build_message_finish(
usage=usage,
@@ -588,9 +584,7 @@ async def achunks_to_events(
) -> AsyncIterator[MessagesData]:
"""Async variant of `chunks_to_events`."""
started = False
open_key: Any = None
open_block: CompatBlock | None = None
open_wire_idx: int = 0
blocks: dict[Any, tuple[int, CompatBlock]] = {}
next_wire_idx = 0
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
@@ -615,24 +609,22 @@ async def achunks_to_events(
yield _build_message_start(msg, message_id)
for key, block in _iter_protocol_blocks(msg):
if key != open_key:
if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
open_key = key
open_wire_idx = next_wire_idx
if key not in blocks:
wire_idx = next_wire_idx
next_wire_idx += 1
open_block = dict(block)
blocks[key] = (wire_idx, dict(block))
yield ContentBlockStartData(
event="content-block-start",
index=open_wire_idx,
index=wire_idx,
content_block=_start_skeleton(block),
)
else:
open_block = _accumulate(open_block, block)
wire_idx, existing = blocks[key]
blocks[key] = (wire_idx, _accumulate(existing, block))
if _should_emit_delta(block):
yield ContentBlockDeltaData(
event="content-block-delta",
index=open_wire_idx,
index=wire_idx,
content_block=_to_protocol_block(block),
)
@@ -642,8 +634,8 @@ async def achunks_to_events(
if not started:
return
if open_block is not None:
yield _finalize_and_build_finish(open_wire_idx, open_block)
for wire_idx, block in blocks.values():
yield _finalize_and_build_finish(wire_idx, block)
yield _build_message_finish(
usage=usage,

View File

@@ -1,5 +1,6 @@
"""Tests for the compat bridge (chunk-to-event conversion)."""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, cast
import pytest
@@ -9,6 +10,7 @@ from langchain_core.language_models._compat_bridge import (
CompatBlock,
_finalize_block,
_to_protocol_usage,
achunks_to_events,
amessage_to_events,
chunks_to_events,
message_to_events,
@@ -143,14 +145,14 @@ def test_chunks_to_events_empty_iterator() -> None:
assert list(chunks_to_events(iter([]))) == []
def test_chunks_to_events_block_transitions_close_previous_block() -> None:
def test_chunks_to_events_block_transitions_keep_stable_indices() -> None:
"""String-keyed blocks that transition mid-stream each get their own lifecycle.
Regression test for OpenAI `responses/v1` style streams where
`content_blocks` uses string identifiers (e.g. `"lc_rs_305f30"`) to
distinguish blocks. Each distinct block must get its own
`content-block-start` / `content-block-finish` pair, with sequential
`uint` wire indices, and blocks must not interleave.
`uint` wire indices, and deltas keep that stable wire index.
"""
chunks = [
ChatGenerationChunk(
@@ -213,8 +215,8 @@ def test_chunks_to_events_block_transitions_close_previous_block() -> None:
assert [s["index"] for s in starts] == [0, 1, 2]
assert [f["index"] for f in finishes] == [0, 1, 2]
# Finish events must be interleaved with starts (no-interleave rule):
# block 0 finishes before block 1 starts, etc.
# Blocks are finalized at message end so providers can interleave
# deltas for parallel content blocks without closing them early.
events_any: list[Any] = events
lifecycle = [
(e["event"], e["index"])
@@ -223,10 +225,10 @@ def test_chunks_to_events_block_transitions_close_previous_block() -> None:
]
assert lifecycle == [
("content-block-start", 0),
("content-block-finish", 0),
("content-block-start", 1),
("content-block-finish", 1),
("content-block-start", 2),
("content-block-finish", 0),
("content-block-finish", 1),
("content-block-finish", 2),
]
@@ -297,6 +299,124 @@ def test_chunks_to_events_tool_call_multichunk() -> None:
)
def test_chunks_to_events_interleaved_parallel_tool_calls() -> None:
"""Parallel tool-call chunks can interleave without losing block lifecycles."""
events = list(
chunks_to_events(
iter(_interleaved_parallel_tool_call_chunks()), message_id="msg-1"
)
)
_assert_interleaved_parallel_tool_call_events(events)
@pytest.mark.asyncio
async def test_achunks_to_events_interleaved_parallel_tool_calls() -> None:
"""Async bridge preserves lifecycles for interleaved parallel tool calls."""
events = [
event
async for event in achunks_to_events(
_aiter_chunks(_interleaved_parallel_tool_call_chunks()),
message_id="msg-1",
)
]
_assert_interleaved_parallel_tool_call_events(events)
def _interleaved_parallel_tool_call_chunks() -> list[ChatGenerationChunk]:
return [
ChatGenerationChunk(
message=AIMessageChunk(
content="",
id="msg-1",
tool_call_chunks=[
{
"index": 0,
"id": "tc1",
"name": "task",
"args": '{"subagent_type": "haiku"',
"type": "tool_call_chunk",
}
],
)
),
ChatGenerationChunk(
message=AIMessageChunk(
content="",
id="msg-1",
tool_call_chunks=[
{
"index": 1,
"id": "tc2",
"name": "task",
"args": '{"subagent_type": "limerick"',
"type": "tool_call_chunk",
}
],
)
),
ChatGenerationChunk(
message=AIMessageChunk(
content="",
id="msg-1",
tool_call_chunks=[
{
"index": 0,
"id": None,
"name": None,
"args": ', "description": "Write a haiku"}',
"type": "tool_call_chunk",
}
],
)
),
ChatGenerationChunk(
message=AIMessageChunk(
content="",
id="msg-1",
tool_call_chunks=[
{
"index": 1,
"id": None,
"name": None,
"args": ', "description": "Write a limerick"}',
"type": "tool_call_chunk",
}
],
)
),
]
async def _aiter_chunks(
chunks: list[ChatGenerationChunk],
) -> AsyncIterator[ChatGenerationChunk]:
for chunk in chunks:
yield chunk
def _assert_interleaved_parallel_tool_call_events(events: list[Any]) -> None:
assert_valid_event_stream(events)
starts: list[Any] = [e for e in events if e["event"] == "content-block-start"]
finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"]
assert [s["index"] for s in starts] == [0, 1]
assert [f["index"] for f in finishes] == [0, 1]
finalized = [cast("ToolCall", event["content_block"]) for event in finishes]
assert finalized[0]["id"] == "tc1"
assert finalized[0]["args"] == {
"subagent_type": "haiku",
"description": "Write a haiku",
}
assert finalized[1]["id"] == "tc2"
assert finalized[1]["args"] == {
"subagent_type": "limerick",
"description": "Write a limerick",
}
def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None:
"""Malformed tool-args become invalid_tool_call; finish_reason stays `stop`."""
chunks = [

View File

@@ -5,9 +5,10 @@ or by the compat bridge's `chunks_to_events` / `message_to_events`)
conforms to the protocol lifecycle rules:
- `message-start` opens and `message-finish` closes the stream.
- Content blocks do not interleave: each block runs
- Content blocks may interleave: each block index runs
`content-block-start` → optional `content-block-delta`s →
`content-block-finish` before the next block begins.
`content-block-finish`, while other block indices may start or receive
deltas before that block finishes.
- Wire indices on content-block events are sequential `uint` values
starting at 0.
- For deltaable block types (`text`, `reasoning`, `tool_call_chunk`,
@@ -71,7 +72,7 @@ def assert_valid_event_stream(events: Iterable[Any]) -> None:
"`message-finish` must be the final event"
)
open_idx: int | None = None
open_indices: set[int] = set()
expected_next_idx = 0
start_events: dict[int, dict[str, Any]] = {}
finish_events: dict[int, dict[str, Any]] = {}
@@ -83,8 +84,9 @@ def assert_valid_event_stream(events: Iterable[Any]) -> None:
assert i == 0, f"duplicate `message-start` at event {i}"
continue
if ev == "message-finish":
assert open_idx is None, (
f"`message-finish` while block {open_idx} still open (event {i})"
assert not open_indices, (
f"`message-finish` while blocks {sorted(open_indices)} "
f"still open (event {i})"
)
continue
if ev == "error":
@@ -102,36 +104,39 @@ def assert_valid_event_stream(events: Iterable[Any]) -> None:
assert idx == expected_next_idx, (
f"expected next wire index {expected_next_idx}, got {idx} at event {i}"
)
assert open_idx is None, (
f"content-block-start at idx={idx} while block {open_idx} "
f"still open (event {i}); blocks must not interleave"
assert idx not in start_events, (
f"duplicate content-block-start for idx={idx} at event {i}"
)
open_idx = idx
open_indices.add(idx)
start_events[idx] = event["content_block"]
delta_accum[idx] = {}
expected_next_idx += 1
elif ev == "content-block-delta":
idx = event["index"]
assert idx == open_idx, (
f"content-block-delta at idx={idx} but currently-open block is "
f"{open_idx} (event {i})"
assert idx in open_indices, (
f"content-block-delta at idx={idx} but that block is not open "
f"(event {i})"
)
block = event["content_block"]
_accumulate_delta(delta_accum[idx], block)
elif ev == "content-block-finish":
idx = event["index"]
assert idx == open_idx, (
f"content-block-finish at idx={idx} but currently-open block is "
f"{open_idx} (event {i})"
assert idx in open_indices, (
f"content-block-finish at idx={idx} but that block is not open "
f"(event {i})"
)
assert idx not in finish_events, (
f"duplicate content-block-finish for idx={idx} at event {i}"
)
finish_events[idx] = event["content_block"]
open_idx = None
open_indices.remove(idx)
else:
# Unknown event types are accepted; the CDDL allows extensions.
continue
assert open_idx is None, (
f"block {open_idx} still open at end of stream — no content-block-finish"
assert not open_indices, (
f"blocks {sorted(open_indices)} still open at end of stream — "
"no content-block-finish"
)
missing = set(start_events) - set(finish_events)
assert not missing, (