feat(core): align compat bridge with protocol v0.0.11

- Rename content-block imports to new protocol names (TextContentBlock,
  ReasoningContentBlock, InvalidToolCall, ToolCall, ToolCallChunk,
  ServerToolCall, ServerToolCallChunk).
- Drop FinishReason and _normalize_finish_reason: the protocol removed
  ``reason`` from ``MessageFinishData`` in
  2ef8585659.
  Provider-level ``finish_reason`` / ``stop_reason`` now pass through
  verbatim on ``MessageFinishData.metadata`` for downstream consumers.
- Simplify ``_build_message_finish`` and ``_finish_all_blocks``: the
  tool_use re-classification previously driven by the finish reason is
  obsolete now that the wire field is gone.
- Drop the ``_finish_reason`` accumulator from chat_model_stream: the
  same data is now surfaced via ``response_metadata``, populated from the
  passed-through finish metadata.

Made-with: Cursor
This commit is contained in:
Christian Bromann
2026-04-23 16:36:15 -07:00
parent bd8ab5520b
commit 8c404418cf
2 changed files with 57 additions and 119 deletions

View File

@@ -42,18 +42,17 @@ from langchain_protocol.protocol import (
ContentBlockFinishData,
ContentBlockStartData,
FinalizedContentBlock,
FinishReason,
InvalidToolCallBlock,
InvalidToolCall,
MessageFinishData,
MessageMetadata,
MessagesData,
MessageStartData,
ReasoningBlock,
ServerToolCallBlock,
ServerToolCallChunkBlock,
TextBlock,
ToolCallBlock,
ToolCallChunkBlock,
ReasoningContentBlock,
ServerToolCall,
ServerToolCallChunk,
TextContentBlock,
ToolCall,
ToolCallChunk,
UsageInfo,
)
@@ -160,18 +159,18 @@ def _start_skeleton(block: CompatBlock) -> ContentBlock:
"""
btype = block.get("type", "text")
if btype == "text":
return TextBlock(type="text", text="")
return TextContentBlock(type="text", text="")
if btype == "reasoning":
return ReasoningBlock(type="reasoning", reasoning="")
return ReasoningContentBlock(type="reasoning", reasoning="")
if btype == "tool_call_chunk":
skel = ToolCallChunkBlock(type="tool_call_chunk", args="")
skel = ToolCallChunk(type="tool_call_chunk", args="")
if block.get("id") is not None:
skel["id"] = block["id"]
if block.get("name") is not None:
skel["name"] = block["name"]
return skel
if btype == "server_tool_call_chunk":
s_skel = ServerToolCallChunkBlock(
s_skel = ServerToolCallChunk(
type="server_tool_call_chunk",
args="",
)
@@ -275,7 +274,7 @@ def _finalize_block(block: CompatBlock) -> FinalizedContentBlock:
try:
parsed = json.loads(raw) if raw else {}
except (json.JSONDecodeError, TypeError):
invalid = InvalidToolCallBlock(
invalid = InvalidToolCall(
type="invalid_tool_call",
args=raw,
error="Failed to parse tool call arguments as JSON",
@@ -286,13 +285,13 @@ def _finalize_block(block: CompatBlock) -> FinalizedContentBlock:
invalid["name"] = block["name"]
return invalid
if btype == "tool_call_chunk":
return ToolCallBlock(
return ToolCall(
type="tool_call",
id=block.get("id", ""),
name=block.get("name", ""),
args=parsed,
)
return ServerToolCallBlock(
return ServerToolCall(
type="server_tool_call",
id=block.get("id", ""),
name=block.get("name", ""),
@@ -316,17 +315,6 @@ def _extract_start_metadata(response_metadata: dict[str, Any]) -> MessageMetadat
return metadata
def _normalize_finish_reason(value: Any) -> FinishReason:
"""Map provider-specific stop reasons to protocol finish reasons."""
if value == "length":
return "length"
if value == "content_filter":
return "content_filter"
if value in ("tool_use", "tool_calls"):
return "tool_use"
return "stop"
def _accumulate_usage(
current: dict[str, Any] | None, delta: Any
) -> dict[str, Any] | None:
@@ -378,45 +366,29 @@ def _build_message_start(
def _build_message_finish(
*,
finish_reason: FinishReason,
has_valid_tool_call: bool,
usage: dict[str, Any] | None,
response_metadata: dict[str, Any] | None,
) -> MessageFinishData:
# Infer tool_use only from finalized (parsed) tool_calls. An
# invalid_tool_call means parsing failed — the model didn't
# successfully request a tool, so leave finish_reason alone.
if finish_reason == "stop" and has_valid_tool_call:
finish_reason = "tool_use"
finish_data = MessageFinishData(event="message-finish", reason=finish_reason)
# Protocol v0.0.11 dropped ``reason`` from ``MessageFinishData``.
# ``finish_reason`` / ``stop_reason`` are still surfaced via
# ``metadata`` for consumers that want the raw provider hint; the
# wire event itself no longer advertises a normalized reason.
finish_data = MessageFinishData(event="message-finish")
usage_info = _to_protocol_usage(usage)
if usage_info is not None:
finish_data["usage"] = usage_info
if response_metadata:
metadata = {
k: v
for k, v in response_metadata.items()
if k not in ("finish_reason", "stop_reason")
}
if metadata:
finish_data["metadata"] = metadata
finish_data["metadata"] = dict(response_metadata)
return finish_data
def _finish_all_blocks(
state: dict[int, CompatBlock],
) -> tuple[list[MessagesData], bool]:
"""Emit `content-block-finish` events for every open block.
Returns the event list plus a flag indicating whether any finalized
block was a valid `tool_call` (used for finish-reason inference).
"""
) -> list[MessagesData]:
"""Emit `content-block-finish` events for every open block."""
events: list[MessagesData] = []
has_valid_tool_call = False
for idx in sorted(state):
finalized = _finalize_block(state[idx])
if finalized.get("type") == "tool_call":
has_valid_tool_call = True
events.append(
ContentBlockFinishData(
event="content-block-finish",
@@ -424,7 +396,7 @@ def _finish_all_blocks(
content_block=finalized,
)
)
return events, has_valid_tool_call
return events
# ---------------------------------------------------------------------------
@@ -451,7 +423,6 @@ def chunks_to_events(
first_seen: set[int] = set()
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
finish_reason: FinishReason = "stop"
for chunk in chunks:
msg = chunk.message
@@ -484,19 +455,11 @@ def chunks_to_events(
if msg.usage_metadata:
usage = _accumulate_usage(usage, msg.usage_metadata)
rm = msg.response_metadata or {}
raw_reason = rm.get("finish_reason") or rm.get("stop_reason")
if raw_reason:
finish_reason = _normalize_finish_reason(raw_reason)
if not started:
return
finish_events, has_valid_tool_call = _finish_all_blocks(state)
yield from finish_events
yield from _finish_all_blocks(state)
yield _build_message_finish(
finish_reason=finish_reason,
has_valid_tool_call=has_valid_tool_call,
usage=usage,
response_metadata=response_metadata,
)
@@ -513,7 +476,6 @@ async def achunks_to_events(
first_seen: set[int] = set()
usage: dict[str, Any] | None = None
response_metadata: dict[str, Any] = {}
finish_reason: FinishReason = "stop"
async for chunk in chunks:
msg = chunk.message
@@ -546,20 +508,12 @@ async def achunks_to_events(
if msg.usage_metadata:
usage = _accumulate_usage(usage, msg.usage_metadata)
rm = msg.response_metadata or {}
raw_reason = rm.get("finish_reason") or rm.get("stop_reason")
if raw_reason:
finish_reason = _normalize_finish_reason(raw_reason)
if not started:
return
finish_events, has_valid_tool_call = _finish_all_blocks(state)
for event in finish_events:
for event in _finish_all_blocks(state):
yield event
yield _build_message_finish(
finish_reason=finish_reason,
has_valid_tool_call=has_valid_tool_call,
usage=usage,
response_metadata=response_metadata,
)
@@ -592,7 +546,6 @@ def message_to_events(
response_metadata = msg.response_metadata or {}
yield _build_message_start(msg, message_id)
has_valid_tool_call = False
for idx, block in _iter_protocol_blocks(msg):
yield ContentBlockStartData(
event="content-block-start",
@@ -605,24 +558,13 @@ def message_to_events(
index=idx,
content_block=_to_protocol_delta_block(block),
)
finalized = _finalize_block(block)
if finalized.get("type") == "tool_call":
has_valid_tool_call = True
yield ContentBlockFinishData(
event="content-block-finish",
index=idx,
content_block=finalized,
content_block=_finalize_block(block),
)
raw_reason = response_metadata.get("finish_reason") or response_metadata.get(
"stop_reason"
)
finish_reason: FinishReason = (
_normalize_finish_reason(raw_reason) if raw_reason else "stop"
)
yield _build_message_finish(
finish_reason=finish_reason,
has_valid_tool_call=has_valid_tool_call,
usage=getattr(msg, "usage_metadata", None),
response_metadata=response_metadata,
)

View File

@@ -23,15 +23,15 @@ from typing import TYPE_CHECKING, Any, cast
from langchain_protocol.protocol import (
ContentBlockDeltaData,
ContentBlockFinishData,
InvalidToolCallBlock,
InvalidToolCall,
MessageFinishData,
MessageMetadata,
MessageStartData,
ReasoningBlock,
ServerToolCallChunkBlock,
TextBlock,
ToolCallBlock,
ToolCallChunkBlock,
ReasoningContentBlock,
ServerToolCallChunk,
TextContentBlock,
ToolCall,
ToolCallChunk,
UsageInfo,
)
@@ -68,8 +68,8 @@ def _sweep_chunk_store(
*,
finalized_type: str,
finalized_blocks: dict[int, FinalizedContentBlock],
tool_calls_acc: list[ToolCallBlock] | None,
invalid_acc: list[InvalidToolCallBlock],
tool_calls_acc: list[ToolCall] | None,
invalid_acc: list[InvalidToolCall],
) -> None:
"""Parse each unswept chunk's `args`; record as `finalized_type` or invalid.
@@ -82,7 +82,7 @@ def _sweep_chunk_store(
try:
parsed = json.loads(raw_args) if raw_args else {}
except (json.JSONDecodeError, TypeError):
invalid: InvalidToolCallBlock = {
invalid: InvalidToolCall = {
"type": "invalid_tool_call",
"args": raw_args or "",
"error": "Failed to parse tool call arguments as JSON",
@@ -104,7 +104,7 @@ def _sweep_chunk_store(
},
)
if tool_calls_acc is not None and finalized_type == "tool_call":
tool_calls_acc.append(cast("ToolCallBlock", final_block))
tool_calls_acc.append(cast("ToolCall", final_block))
finalized_blocks[idx] = final_block
store.clear()
@@ -416,8 +416,8 @@ class ChatModelStream:
- `.text` — iterable of `str` deltas; `str()` for full text
- `.reasoning` — same as `.text` for reasoning content
- `.tool_calls` — iterable of `ToolCallChunkBlock` deltas;
`.get()` returns `list[ToolCallBlock]`
- `.tool_calls` — iterable of `ToolCallChunk` deltas;
`.get()` returns `list[ToolCall]`
- `.usage` — blocking property, returns `UsageInfo | None`
- `.output` — blocking property, returns assembled `AIMessage`
@@ -442,8 +442,8 @@ class ChatModelStream:
self._text_acc: str = ""
self._reasoning_acc: str = ""
self._tool_call_chunks: dict[int, dict[str, Any]] = {}
self._tool_calls_acc: list[ToolCallBlock] = []
self._invalid_tool_calls_acc: list[InvalidToolCallBlock] = []
self._tool_calls_acc: list[ToolCall] = []
self._invalid_tool_calls_acc: list[InvalidToolCall] = []
self._server_tool_call_chunks: dict[int, dict[str, Any]] = {}
# Ordered snapshot of every finalized block, keyed by event index.
# Single source of truth for .output.content. Typed accumulators
@@ -451,7 +451,6 @@ class ChatModelStream:
# the public projections.
self._blocks: dict[int, FinalizedContentBlock] = {}
self._usage_value: UsageInfo | None = None
self._finish_reason: str | None = None
self._start_metadata: MessageMetadata | None = None
self._finish_metadata: dict[str, Any] | None = None
self._done: bool = False
@@ -504,9 +503,9 @@ class ChatModelStream:
@property
def tool_calls(self) -> SyncProjection:
"""Tool calls — iterable of `ToolCallChunkBlock` deltas.
"""Tool calls — iterable of `ToolCallChunk` deltas.
`.get()` returns finalized `list[ToolCallBlock]`.
`.get()` returns finalized `list[ToolCall]`.
"""
return self._tool_calls_proj
@@ -620,7 +619,7 @@ class ChatModelStream:
# If the source exhausted without a message-finish event
# (e.g., empty response), finalize with what we have.
if not self._done:
self._finish(MessageFinishData(event="message-finish", reason="stop"))
self._finish(MessageFinishData(event="message-finish"))
# -- Internal push API (called by dispatch) ----------------------------
@@ -640,19 +639,19 @@ class ChatModelStream:
btype = block.get("type", "")
if btype == "text":
text_block = cast("TextBlock", block)
text_block = cast("TextContentBlock", block)
delta_text = text_block.get("text", "")
if delta_text:
self._text_acc += delta_text
self._text_proj.push(delta_text)
elif btype == "reasoning":
reasoning_block = cast("ReasoningBlock", block)
reasoning_block = cast("ReasoningContentBlock", block)
delta_r = reasoning_block.get("reasoning", "")
if delta_r:
self._reasoning_acc += delta_r
self._reasoning_proj.push(delta_r)
elif btype == "tool_call_chunk":
tcc = cast("ToolCallChunkBlock", block)
tcc = cast("ToolCallChunk", block)
# The protocol puts the block index on the event
# (``ContentBlockDeltaData``), not inside ``content_block``.
# Fall back to ``content_block.index`` for providers that echo
@@ -661,7 +660,7 @@ class ChatModelStream:
if idx is None:
idx = tcc.get("index", len(self._tool_call_chunks))
_merge_chunk_into_store(self._tool_call_chunks, idx, dict(tcc))
chunk_block = ToolCallChunkBlock(type="tool_call_chunk")
chunk_block = ToolCallChunk(type="tool_call_chunk")
if tcc.get("id"):
chunk_block["id"] = tcc["id"]
if tcc.get("name"):
@@ -672,7 +671,7 @@ class ChatModelStream:
chunk_block["index"] = tcc["index"]
self._tool_calls_proj.push(chunk_block)
elif btype == "server_tool_call_chunk":
stcc = cast("ServerToolCallChunkBlock", block)
stcc = cast("ServerToolCallChunk", block)
idx = data.get("index")
if idx is None:
idx = len(self._server_tool_call_chunks)
@@ -692,7 +691,7 @@ class ChatModelStream:
finalized: FinalizedContentBlock | None = None
if btype == "text":
text_block = cast("TextBlock", block)
text_block = cast("TextContentBlock", block)
full_text = text_block.get("text", "")
if full_text and full_text != self._text_acc:
self._text_acc = full_text
@@ -701,7 +700,7 @@ class ChatModelStream:
{"type": "text", "text": self._text_acc},
)
elif btype == "reasoning":
reasoning_block = cast("ReasoningBlock", block)
reasoning_block = cast("ReasoningContentBlock", block)
full_r = reasoning_block.get("reasoning", "")
if full_r and full_r != self._reasoning_acc:
self._reasoning_acc = full_r
@@ -710,8 +709,8 @@ class ChatModelStream:
{"type": "reasoning", "reasoning": self._reasoning_acc},
)
elif btype == "tool_call":
tcb = cast("ToolCallBlock", block)
tc = ToolCallBlock(
tcb = cast("ToolCall", block)
tc = ToolCall(
type="tool_call",
id=tcb.get("id", ""),
name=tcb.get("name", ""),
@@ -722,10 +721,10 @@ class ChatModelStream:
del self._tool_call_chunks[idx]
finalized = tc
elif btype == "invalid_tool_call":
itc = cast("InvalidToolCallBlock", block)
itc = cast("InvalidToolCall", block)
self._invalid_tool_calls_acc.append(itc)
# Critical: drop the stale chunk so _finish's sweep doesn't revive
# it as an empty-args ToolCallBlock.
# it as an empty-args ToolCall.
if idx is not None and idx in self._tool_call_chunks:
del self._tool_call_chunks[idx]
if idx is not None and idx in self._server_tool_call_chunks:
@@ -751,7 +750,6 @@ class ChatModelStream:
"""Process a `message-finish` event."""
self._done = True
self._usage_value = data.get("usage")
self._finish_reason = data.get("reason")
self._finish_metadata = data.get("metadata")
# Finalize any unswept chunks — both client- and server-side.
@@ -803,13 +801,11 @@ class ChatModelStream:
else:
ordered_blocks = [self._blocks[idx] for idx in sorted(self._blocks)]
if len(ordered_blocks) == 1 and ordered_blocks[0].get("type") == "text":
content = cast("TextBlock", ordered_blocks[0]).get("text", "")
content = cast("TextContentBlock", ordered_blocks[0]).get("text", "")
else:
content = [dict(b) for b in ordered_blocks]
response_metadata: dict[str, Any] = {"output_version": "v1"}
if self._finish_reason:
response_metadata["finish_reason"] = self._finish_reason
if self._start_metadata:
if "provider" in self._start_metadata:
response_metadata["model_provider"] = self._start_metadata["provider"]
@@ -864,8 +860,8 @@ class AsyncChatModelStream(ChatModelStream):
- `.text` — async iterable of text deltas; awaitable for full text
- `.reasoning` — async iterable of reasoning deltas; awaitable
- `.tool_calls` — async iterable of `ToolCallChunkBlock` deltas;
awaitable for `list[ToolCallBlock]`
- `.tool_calls` — async iterable of `ToolCallChunk` deltas;
awaitable for `list[ToolCall]`
- `.usage` — awaitable for `UsageInfo`
- `.output` — awaitable for assembled `AIMessage`