mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-03 01:46:42 +00:00
fix(core): support interleaved content blocks in stream_v2 bridge
This commit is contained in:
@@ -497,12 +497,12 @@ def chunks_to_events(
|
||||
) -> Iterator[MessagesData]:
|
||||
"""Convert a stream of `ChatGenerationChunk` to protocol events.
|
||||
|
||||
Blocks stream one at a time: when a chunk carries a different block
|
||||
identifier than the currently-open one, the open block is finished
|
||||
before the new block starts, matching the protocol's no-interleave
|
||||
rule. Source-side identifiers (from the block's `index` field, which
|
||||
may be int or string) are translated to sequential `uint` wire
|
||||
indices.
|
||||
Blocks are tracked independently by source-side identifier. Providers
|
||||
such as Anthropic can interleave parallel tool-call chunks by index, so
|
||||
each first-seen block gets a `content-block-start`, deltas keep their
|
||||
stable wire index, and all open blocks are finalized at message end.
|
||||
Source-side identifiers (from the block's `index` field, which may be
|
||||
int or string) are translated to sequential `uint` wire indices.
|
||||
|
||||
Args:
|
||||
chunks: Iterator of `ChatGenerationChunk` from `_stream()`.
|
||||
@@ -512,9 +512,7 @@ def chunks_to_events(
|
||||
`MessagesData` lifecycle events.
|
||||
"""
|
||||
started = False
|
||||
open_key: Any = None
|
||||
open_block: CompatBlock | None = None
|
||||
open_wire_idx: int = 0
|
||||
blocks: dict[Any, tuple[int, CompatBlock]] = {}
|
||||
next_wire_idx = 0
|
||||
usage: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] = {}
|
||||
@@ -545,24 +543,22 @@ def chunks_to_events(
|
||||
yield _build_message_start(msg, message_id)
|
||||
|
||||
for key, block in _iter_protocol_blocks(msg):
|
||||
if key != open_key:
|
||||
if open_block is not None:
|
||||
yield _finalize_and_build_finish(open_wire_idx, open_block)
|
||||
open_key = key
|
||||
open_wire_idx = next_wire_idx
|
||||
if key not in blocks:
|
||||
wire_idx = next_wire_idx
|
||||
next_wire_idx += 1
|
||||
open_block = dict(block)
|
||||
blocks[key] = (wire_idx, dict(block))
|
||||
yield ContentBlockStartData(
|
||||
event="content-block-start",
|
||||
index=open_wire_idx,
|
||||
index=wire_idx,
|
||||
content_block=_start_skeleton(block),
|
||||
)
|
||||
else:
|
||||
open_block = _accumulate(open_block, block)
|
||||
wire_idx, existing = blocks[key]
|
||||
blocks[key] = (wire_idx, _accumulate(existing, block))
|
||||
if _should_emit_delta(block):
|
||||
yield ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=open_wire_idx,
|
||||
index=wire_idx,
|
||||
content_block=_to_protocol_block(block),
|
||||
)
|
||||
|
||||
@@ -572,8 +568,8 @@ def chunks_to_events(
|
||||
if not started:
|
||||
return
|
||||
|
||||
if open_block is not None:
|
||||
yield _finalize_and_build_finish(open_wire_idx, open_block)
|
||||
for wire_idx, block in blocks.values():
|
||||
yield _finalize_and_build_finish(wire_idx, block)
|
||||
|
||||
yield _build_message_finish(
|
||||
usage=usage,
|
||||
@@ -588,9 +584,7 @@ async def achunks_to_events(
|
||||
) -> AsyncIterator[MessagesData]:
|
||||
"""Async variant of `chunks_to_events`."""
|
||||
started = False
|
||||
open_key: Any = None
|
||||
open_block: CompatBlock | None = None
|
||||
open_wire_idx: int = 0
|
||||
blocks: dict[Any, tuple[int, CompatBlock]] = {}
|
||||
next_wire_idx = 0
|
||||
usage: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] = {}
|
||||
@@ -615,24 +609,22 @@ async def achunks_to_events(
|
||||
yield _build_message_start(msg, message_id)
|
||||
|
||||
for key, block in _iter_protocol_blocks(msg):
|
||||
if key != open_key:
|
||||
if open_block is not None:
|
||||
yield _finalize_and_build_finish(open_wire_idx, open_block)
|
||||
open_key = key
|
||||
open_wire_idx = next_wire_idx
|
||||
if key not in blocks:
|
||||
wire_idx = next_wire_idx
|
||||
next_wire_idx += 1
|
||||
open_block = dict(block)
|
||||
blocks[key] = (wire_idx, dict(block))
|
||||
yield ContentBlockStartData(
|
||||
event="content-block-start",
|
||||
index=open_wire_idx,
|
||||
index=wire_idx,
|
||||
content_block=_start_skeleton(block),
|
||||
)
|
||||
else:
|
||||
open_block = _accumulate(open_block, block)
|
||||
wire_idx, existing = blocks[key]
|
||||
blocks[key] = (wire_idx, _accumulate(existing, block))
|
||||
if _should_emit_delta(block):
|
||||
yield ContentBlockDeltaData(
|
||||
event="content-block-delta",
|
||||
index=open_wire_idx,
|
||||
index=wire_idx,
|
||||
content_block=_to_protocol_block(block),
|
||||
)
|
||||
|
||||
@@ -642,8 +634,8 @@ async def achunks_to_events(
|
||||
if not started:
|
||||
return
|
||||
|
||||
if open_block is not None:
|
||||
yield _finalize_and_build_finish(open_wire_idx, open_block)
|
||||
for wire_idx, block in blocks.values():
|
||||
yield _finalize_and_build_finish(wire_idx, block)
|
||||
|
||||
yield _build_message_finish(
|
||||
usage=usage,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for the compat bridge (chunk-to-event conversion)."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import pytest
|
||||
@@ -9,6 +10,7 @@ from langchain_core.language_models._compat_bridge import (
|
||||
CompatBlock,
|
||||
_finalize_block,
|
||||
_to_protocol_usage,
|
||||
achunks_to_events,
|
||||
amessage_to_events,
|
||||
chunks_to_events,
|
||||
message_to_events,
|
||||
@@ -143,14 +145,14 @@ def test_chunks_to_events_empty_iterator() -> None:
|
||||
assert list(chunks_to_events(iter([]))) == []
|
||||
|
||||
|
||||
def test_chunks_to_events_block_transitions_close_previous_block() -> None:
|
||||
def test_chunks_to_events_block_transitions_keep_stable_indices() -> None:
|
||||
"""String-keyed blocks that transition mid-stream each get their own lifecycle.
|
||||
|
||||
Regression test for OpenAI `responses/v1` style streams where
|
||||
`content_blocks` uses string identifiers (e.g. `"lc_rs_305f30"`) to
|
||||
distinguish blocks. Each distinct block must get its own
|
||||
`content-block-start` / `content-block-finish` pair, with sequential
|
||||
`uint` wire indices, and blocks must not interleave.
|
||||
`uint` wire indices, and deltas keep that stable wire index.
|
||||
"""
|
||||
chunks = [
|
||||
ChatGenerationChunk(
|
||||
@@ -213,8 +215,8 @@ def test_chunks_to_events_block_transitions_close_previous_block() -> None:
|
||||
assert [s["index"] for s in starts] == [0, 1, 2]
|
||||
assert [f["index"] for f in finishes] == [0, 1, 2]
|
||||
|
||||
# Finish events must be interleaved with starts (no-interleave rule):
|
||||
# block 0 finishes before block 1 starts, etc.
|
||||
# Blocks are finalized at message end so providers can interleave
|
||||
# deltas for parallel content blocks without closing them early.
|
||||
events_any: list[Any] = events
|
||||
lifecycle = [
|
||||
(e["event"], e["index"])
|
||||
@@ -223,10 +225,10 @@ def test_chunks_to_events_block_transitions_close_previous_block() -> None:
|
||||
]
|
||||
assert lifecycle == [
|
||||
("content-block-start", 0),
|
||||
("content-block-finish", 0),
|
||||
("content-block-start", 1),
|
||||
("content-block-finish", 1),
|
||||
("content-block-start", 2),
|
||||
("content-block-finish", 0),
|
||||
("content-block-finish", 1),
|
||||
("content-block-finish", 2),
|
||||
]
|
||||
|
||||
@@ -297,6 +299,124 @@ def test_chunks_to_events_tool_call_multichunk() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_chunks_to_events_interleaved_parallel_tool_calls() -> None:
|
||||
"""Parallel tool-call chunks can interleave without losing block lifecycles."""
|
||||
events = list(
|
||||
chunks_to_events(
|
||||
iter(_interleaved_parallel_tool_call_chunks()), message_id="msg-1"
|
||||
)
|
||||
)
|
||||
|
||||
_assert_interleaved_parallel_tool_call_events(events)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achunks_to_events_interleaved_parallel_tool_calls() -> None:
|
||||
"""Async bridge preserves lifecycles for interleaved parallel tool calls."""
|
||||
events = [
|
||||
event
|
||||
async for event in achunks_to_events(
|
||||
_aiter_chunks(_interleaved_parallel_tool_call_chunks()),
|
||||
message_id="msg-1",
|
||||
)
|
||||
]
|
||||
|
||||
_assert_interleaved_parallel_tool_call_events(events)
|
||||
|
||||
|
||||
def _interleaved_parallel_tool_call_chunks() -> list[ChatGenerationChunk]:
|
||||
return [
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-1",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 0,
|
||||
"id": "tc1",
|
||||
"name": "task",
|
||||
"args": '{"subagent_type": "haiku"',
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-1",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 1,
|
||||
"id": "tc2",
|
||||
"name": "task",
|
||||
"args": '{"subagent_type": "limerick"',
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-1",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 0,
|
||||
"id": None,
|
||||
"name": None,
|
||||
"args": ', "description": "Write a haiku"}',
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
id="msg-1",
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"index": 1,
|
||||
"id": None,
|
||||
"name": None,
|
||||
"args": ', "description": "Write a limerick"}',
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def _aiter_chunks(
|
||||
chunks: list[ChatGenerationChunk],
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
def _assert_interleaved_parallel_tool_call_events(events: list[Any]) -> None:
|
||||
assert_valid_event_stream(events)
|
||||
|
||||
starts: list[Any] = [e for e in events if e["event"] == "content-block-start"]
|
||||
finishes: list[Any] = [e for e in events if e["event"] == "content-block-finish"]
|
||||
assert [s["index"] for s in starts] == [0, 1]
|
||||
assert [f["index"] for f in finishes] == [0, 1]
|
||||
|
||||
finalized = [cast("ToolCall", event["content_block"]) for event in finishes]
|
||||
assert finalized[0]["id"] == "tc1"
|
||||
assert finalized[0]["args"] == {
|
||||
"subagent_type": "haiku",
|
||||
"description": "Write a haiku",
|
||||
}
|
||||
assert finalized[1]["id"] == "tc2"
|
||||
assert finalized[1]["args"] == {
|
||||
"subagent_type": "limerick",
|
||||
"description": "Write a limerick",
|
||||
}
|
||||
|
||||
|
||||
def test_chunks_to_events_invalid_tool_call_keeps_stop_reason() -> None:
|
||||
"""Malformed tool-args become invalid_tool_call; finish_reason stays `stop`."""
|
||||
chunks = [
|
||||
|
||||
@@ -5,9 +5,10 @@ or by the compat bridge's `chunks_to_events` / `message_to_events`)
|
||||
conforms to the protocol lifecycle rules:
|
||||
|
||||
- `message-start` opens and `message-finish` closes the stream.
|
||||
- Content blocks do not interleave: each block runs
|
||||
- Content blocks may interleave: each block index runs
|
||||
`content-block-start` → optional `content-block-delta`s →
|
||||
`content-block-finish` before the next block begins.
|
||||
`content-block-finish`, while other block indices may start or receive
|
||||
deltas before that block finishes.
|
||||
- Wire indices on content-block events are sequential `uint` values
|
||||
starting at 0.
|
||||
- For deltaable block types (`text`, `reasoning`, `tool_call_chunk`,
|
||||
@@ -71,7 +72,7 @@ def assert_valid_event_stream(events: Iterable[Any]) -> None:
|
||||
"`message-finish` must be the final event"
|
||||
)
|
||||
|
||||
open_idx: int | None = None
|
||||
open_indices: set[int] = set()
|
||||
expected_next_idx = 0
|
||||
start_events: dict[int, dict[str, Any]] = {}
|
||||
finish_events: dict[int, dict[str, Any]] = {}
|
||||
@@ -83,8 +84,9 @@ def assert_valid_event_stream(events: Iterable[Any]) -> None:
|
||||
assert i == 0, f"duplicate `message-start` at event {i}"
|
||||
continue
|
||||
if ev == "message-finish":
|
||||
assert open_idx is None, (
|
||||
f"`message-finish` while block {open_idx} still open (event {i})"
|
||||
assert not open_indices, (
|
||||
f"`message-finish` while blocks {sorted(open_indices)} "
|
||||
f"still open (event {i})"
|
||||
)
|
||||
continue
|
||||
if ev == "error":
|
||||
@@ -102,36 +104,39 @@ def assert_valid_event_stream(events: Iterable[Any]) -> None:
|
||||
assert idx == expected_next_idx, (
|
||||
f"expected next wire index {expected_next_idx}, got {idx} at event {i}"
|
||||
)
|
||||
assert open_idx is None, (
|
||||
f"content-block-start at idx={idx} while block {open_idx} "
|
||||
f"still open (event {i}); blocks must not interleave"
|
||||
assert idx not in start_events, (
|
||||
f"duplicate content-block-start for idx={idx} at event {i}"
|
||||
)
|
||||
open_idx = idx
|
||||
open_indices.add(idx)
|
||||
start_events[idx] = event["content_block"]
|
||||
delta_accum[idx] = {}
|
||||
expected_next_idx += 1
|
||||
elif ev == "content-block-delta":
|
||||
idx = event["index"]
|
||||
assert idx == open_idx, (
|
||||
f"content-block-delta at idx={idx} but currently-open block is "
|
||||
f"{open_idx} (event {i})"
|
||||
assert idx in open_indices, (
|
||||
f"content-block-delta at idx={idx} but that block is not open "
|
||||
f"(event {i})"
|
||||
)
|
||||
block = event["content_block"]
|
||||
_accumulate_delta(delta_accum[idx], block)
|
||||
elif ev == "content-block-finish":
|
||||
idx = event["index"]
|
||||
assert idx == open_idx, (
|
||||
f"content-block-finish at idx={idx} but currently-open block is "
|
||||
f"{open_idx} (event {i})"
|
||||
assert idx in open_indices, (
|
||||
f"content-block-finish at idx={idx} but that block is not open "
|
||||
f"(event {i})"
|
||||
)
|
||||
assert idx not in finish_events, (
|
||||
f"duplicate content-block-finish for idx={idx} at event {i}"
|
||||
)
|
||||
finish_events[idx] = event["content_block"]
|
||||
open_idx = None
|
||||
open_indices.remove(idx)
|
||||
else:
|
||||
# Unknown event types are accepted; the CDDL allows extensions.
|
||||
continue
|
||||
|
||||
assert open_idx is None, (
|
||||
f"block {open_idx} still open at end of stream — no content-block-finish"
|
||||
assert not open_indices, (
|
||||
f"blocks {sorted(open_indices)} still open at end of stream — "
|
||||
"no content-block-finish"
|
||||
)
|
||||
missing = set(start_events) - set(finish_events)
|
||||
assert not missing, (
|
||||
|
||||
Reference in New Issue
Block a user