Compare commits

...

1 Commits

Author SHA1 Message Date
Mason Daugherty
89943726d3 fix(core): (WIP) google tool call streaming fixes 2026-01-14 22:47:52 -05:00
2 changed files with 109 additions and 0 deletions

View File

@@ -542,12 +542,41 @@ def translate_content(message: AIMessage) -> list[types.ContentBlock]:
def translate_content_chunk(message: AIMessageChunk) -> list[types.ContentBlock]:
"""Derive standard content blocks from a chunk with Google (GenAI) content.
For intermediate streaming chunks (`chunk_position != 'last'`), tool calls are
returned as `tool_call_chunk` blocks. For final/aggregated chunks, tool calls are
returned as `tool_call` blocks.
Args:
message: The message chunk to translate.
Returns:
The derived content blocks.
"""
# For intermediate streaming chunks with tool_call_chunks,
# produce tool_call_chunk blocks
if message.tool_call_chunks and message.chunk_position != "last":
# Start with non-tool-call content from the base conversion
blocks = _convert_to_v1_from_genai(message)
# Remove any tool_call blocks added by _convert_to_v1_from_genai
# (they were derived from message.tool_calls which auto-parses tool_call_chunks)
blocks = [b for b in blocks if b.get("type") != "tool_call"]
# Add tool_call_chunk blocks from tool_call_chunks
for chunk in message.tool_call_chunks:
tool_call_chunk_block: types.ToolCallChunk = {
"type": "tool_call_chunk",
"id": chunk.get("id"),
"name": chunk.get("name"),
"args": chunk.get("args"),
}
if (idx := chunk.get("index")) is not None:
tool_call_chunk_block["index"] = idx
blocks.append(tool_call_chunk_block)
return blocks
# Final/aggregated chunk or no tool_call_chunks: use standard conversion
return _convert_to_v1_from_genai(message)

View File

@@ -1,8 +1,11 @@
"""Tests for Google GenAI block translator."""
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.block_translators.google_genai import (
translate_content_chunk,
translate_grounding_metadata_to_citations,
)
from langchain_core.messages.tool import tool_call_chunk
def test_translate_grounding_metadata_web() -> None:
@@ -216,3 +219,80 @@ def test_translate_grounding_metadata_multiple_chunks() -> None:
assert (
citations[1].get("extras", {})["google_ai_metadata"]["place_id"] == "places/123"
)
def test_translate_content_chunk_intermediate_streaming() -> None:
"""Intermediate chunks should have `tool_call_chunk` in `content_blocks`."""
chunk = AIMessageChunk(
content=[],
tool_call_chunks=[
tool_call_chunk(name="my_tool", args='{"arg": "value"}', id="123", index=0)
],
response_metadata={"model_provider": "google_genai"},
# No chunk_position set (intermediate chunk)
)
blocks = translate_content_chunk(chunk)
tool_blocks = [b for b in blocks if b.get("type") == "tool_call_chunk"]
assert len(tool_blocks) == 1
assert tool_blocks[0].get("name") == "my_tool"
assert tool_blocks[0].get("args") == '{"arg": "value"}'
assert tool_blocks[0].get("index") == 0
def test_translate_content_chunk_final_chunk() -> None:
"""Final chunks should have `tool_call` in `content_blocks`."""
chunk = AIMessageChunk(
content=[],
tool_call_chunks=[
tool_call_chunk(name="my_tool", args='{"arg": "value"}', id="123")
],
response_metadata={"model_provider": "google_genai"},
chunk_position="last", # Final chunk
)
blocks = translate_content_chunk(chunk)
tool_blocks = [b for b in blocks if b.get("type") == "tool_call"]
assert len(tool_blocks) == 1
assert tool_blocks[0].get("name") == "my_tool"
def test_translate_content_chunk_multiple_tool_calls() -> None:
"""Test intermediate chunk with multiple `tool_call_chunks`."""
chunk = AIMessageChunk(
content=[],
tool_call_chunks=[
tool_call_chunk(name="tool_a", args='{"a": 1}', id="1", index=0),
tool_call_chunk(name="tool_b", args='{"b": 2}', id="2", index=1),
],
response_metadata={"model_provider": "google_genai"},
)
blocks = translate_content_chunk(chunk)
tool_blocks = [b for b in blocks if b.get("type") == "tool_call_chunk"]
assert len(tool_blocks) == 2
assert tool_blocks[0].get("name") == "tool_a"
assert tool_blocks[0].get("index") == 0
assert tool_blocks[1].get("name") == "tool_b"
assert tool_blocks[1].get("index") == 1
def test_translate_content_chunk_with_text_and_tool_call() -> None:
"""Test intermediate chunk with both text `content` and `tool_call_chunks`."""
chunk = AIMessageChunk(
content=[{"type": "text", "text": "Let me call a tool."}],
tool_call_chunks=[
tool_call_chunk(name="my_tool", args='{"arg": "value"}', id="123", index=0)
],
response_metadata={"model_provider": "google_genai"},
)
blocks = translate_content_chunk(chunk)
text_blocks = [b for b in blocks if b.get("type") == "text"]
tool_chunk_blocks = [b for b in blocks if b.get("type") == "tool_call_chunk"]
assert len(text_blocks) == 1
assert text_blocks[0].get("text") == "Let me call a tool."
assert len(tool_chunk_blocks) == 1
assert tool_chunk_blocks[0].get("name") == "my_tool"