fix(core): support content block tokens in callbacks (#34739)

Supersedes #34727
Closes #30703

Related:
* langchain-ai/langchain-google#1460
* langchain-ai/langchain-google#1501

Fixing this at the `langchain-core` callback layer instead of
normalizing inside individual provider integrations, so structured
streaming content is preserved consistently.

---

Models are increasingly streaming structured content blocks instead of
plain text tokens. For example, Gemini 3 can stream text as
content-block lists, and Anthropic/tool-use flows can also produce
non-text message content. Today those values already reach
`on_llm_new_token`, but the callback API still advertises `token: str`,
which makes custom callbacks, tracers, and streaming helpers assume
every streamed value is text.

User story: as a LangChain user building a streaming callback for chat
models with tool calls, reasoning/thinking blocks, or provider-specific
structured content, I need `on_llm_new_token` to accept the same content
shape that chat model chunks can actually emit, so my callback can
observe the stream without providers flattening or dropping non-text
data.

Fixing this in `langchain-core` makes the existing runtime behavior
explicit at the shared callback boundary. Normalizing content blocks
inside each provider would duplicate logic, produce inconsistent
behavior across integrations, and in some cases lose required provider
metadata such as Gemini thought signatures.

## Changes

- Update the callback contract so streamed tokens can be either plain
text or structured content blocks
- Carry structured streamed content through tracing and event/log
streaming paths without forcing provider data into text too early
- Keep built-in text-oriented streaming callbacks working by converting
structured tokens only at the display/queue boundary
- Drop the now-incorrect `cast("str", ...)` on streamed content in
`BaseChatModel` so the producer side matches the widened callback
signature instead of asserting a string it doesn't always have (no
runtime change — `cast` is erased)
- Align Anthropic and Mistral content typing with the structured content
shapes already used by chat model messages
- Update callback tests to reflect that not every streamed value is text

## Compatibility

No runtime behavior change: no producer emits anything it wasn't already
emitting, and widening a parameter type is safe for existing callers and
handlers that pass or receive `str`. The one caveat is downstream code
that subclasses a callback handler or tracer and overrides
`on_llm_new_token` with a `token: str` annotation — under strict type
checking that override is now narrower than the base and will be flagged
as incompatible with the supertype. Such code still runs unchanged; the
fix is to widen the annotation to match.
This commit is contained in:
Mason Daugherty
2026-06-10 16:59:08 -04:00
committed by GitHub
parent 720dfd3b09
commit f89f4c5afe
20 changed files with 75 additions and 57 deletions

View File

@@ -64,7 +64,7 @@ class LLMManagerMixin:
def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -79,7 +79,7 @@ class LLMManagerMixin:
For both chat models and non-chat models (legacy text completion LLMs).
Args:
token: The new token.
token: The new token, or a list of content blocks.
chunk: The new generated chunk, containing content and other information.
run_id: The ID of the current run.
parent_run_id: The ID of the parent run.
@@ -631,7 +631,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -644,7 +644,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
For both chat models and non-chat models (legacy text completion LLMs).
Args:
token: The new token.
token: The new token, or a list of content blocks.
chunk: The new generated chunk, containing content and other information.
run_id: The ID of the current run.
parent_run_id: The ID of the parent run.

View File

@@ -668,7 +668,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
**kwargs: Any,
@@ -676,7 +676,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
"""Run when LLM generates a new token.
Args:
token: The new token.
token: The new token, or a list of content blocks.
chunk: The chunk.
**kwargs: Additional keyword arguments.
@@ -787,7 +787,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
**kwargs: Any,
@@ -795,7 +795,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Run when LLM generates a new token.
Args:
token: The new token.
token: The new token, or a list of content blocks.
chunk: The chunk.
**kwargs: Additional keyword arguments.

View File

@@ -47,14 +47,16 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""
@override
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
def on_llm_new_token(
self, token: str | list[str | dict[str, Any]], **kwargs: Any
) -> None:
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token: The new token.
token: The new token, or a list of content blocks.
**kwargs: Additional keyword arguments.
"""
sys.stdout.write(token)
sys.stdout.write(str(token))
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:

View File

@@ -792,9 +792,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
index += 1
if "index" not in block:
block["index"] = index
run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
chunks.append(chunk)
yield cast("AIMessageChunk", chunk.message)
yielded = True
@@ -927,9 +925,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
index += 1
if "index" not in block:
block["index"] = index
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
await run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
chunks.append(chunk)
yield cast("AIMessageChunk", chunk.message)
yielded = True
@@ -1972,9 +1968,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
if run_manager:
if chunk.message.id is None:
chunk.message.id = run_id
run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
chunks.append(chunk)
yielded = True
@@ -2130,7 +2124,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
if chunk.message.id is None:
chunk.message.id = run_id
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
chunk.message.content, chunk=chunk
)
chunks.append(chunk)
yielded = True

View File

@@ -149,7 +149,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
@override
def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -161,7 +161,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
Only available when streaming is enabled.
Args:
token: The token.
token: The token, or a list of content blocks for structured output.
chunk: The chunk.
run_id: The run ID.
parent_run_id: The parent run ID.
@@ -645,7 +645,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
@override
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -919,7 +919,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
async def _on_llm_new_token(
self,
run: Run,
token: str,
token: str | list[str | dict[str, Any]],
chunk: GenerationChunk | ChatGenerationChunk | None,
) -> None:
"""Process new LLM token."""

View File

@@ -242,7 +242,7 @@ class _TracerCore(ABC):
def _llm_run_with_token_event(
self,
token: str,
token: str | list[str | dict[str, Any]],
run_id: UUID,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
parent_run_id: UUID | None = None,
@@ -602,14 +602,14 @@ class _TracerCore(ABC):
def _on_llm_new_token(
self,
run: Run,
token: str,
token: str | list[str | dict[str, Any]],
chunk: GenerationChunk | ChatGenerationChunk | None,
) -> Coroutine[Any, Any, None] | None:
"""Process new LLM token.
Args:
run: The LLM run.
token: The new token.
token: The new token, or a list of content blocks.
chunk: Optional chunk.
"""
_ = (run, token, chunk)

View File

@@ -429,7 +429,7 @@ class _AstreamEventsCallbackHandler(
@override
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -465,7 +465,8 @@ class _AstreamEventsCallbackHandler(
elif run_info["run_type"] == "llm":
event = "on_llm_stream"
if chunk is None:
chunk_ = GenerationChunk(text=token)
text = token if isinstance(token, str) else ""
chunk_ = GenerationChunk(text=text)
else:
chunk_ = cast("GenerationChunk", chunk)
else:

View File

@@ -359,7 +359,7 @@ class LangChainTracer(BaseTracer):
@override
def _llm_run_with_token_event(
self,
token: str,
token: str | list[str | dict[str, Any]],
run_id: UUID,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
parent_run_id: UUID | None = None,

View File

@@ -537,7 +537,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler[Any]):
def _on_llm_new_token(
self,
run: Run,
token: str,
token: str | list[str | dict[str, Any]],
chunk: GenerationChunk | ChatGenerationChunk | None,
) -> None:
"""Process new LLM token."""

View File

@@ -33,7 +33,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
@override
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,

View File

@@ -182,7 +182,7 @@ async def test_callback_handlers() -> None:
@override
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -190,7 +190,8 @@ async def test_callback_handlers() -> None:
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.store.append(token)
if isinstance(token, str):
self.store.append(token)
infinite_cycle = cycle(
[

View File

@@ -52,19 +52,22 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain -
lint_package: PYTHON_FILES=langchain_classic
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
UV_RUN_LINT = uv run --all-groups
UV_RUN_TYPE = uv run --all-groups
lint_package lint_tests: UV_RUN_LINT = uv run --group lint
lint lint_diff lint_package lint_tests:
./scripts/lint_imports.sh
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
type:
mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check --fix $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check --fix $(PYTHON_FILES)
######################
# HELP

View File

@@ -39,9 +39,12 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.done.clear()
@override
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token is not None and token != "":
self.queue.put_nowait(token)
async def on_llm_new_token(
self, token: str | list[str | dict[str, Any]], **kwargs: Any
) -> None:
token_str = token if isinstance(token, str) else str(token)
if token_str != "":
self.queue.put_nowait(token_str)
@override
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:

View File

@@ -81,9 +81,13 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.done.set()
@override
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
async def on_llm_new_token(
self, token: str | list[str | dict[str, Any]], **kwargs: Any
) -> None:
token_str = token if isinstance(token, str) else str(token)
# Remember the last n tokens, where n = len(answer_prefix_tokens)
self.append_to_last_tokens(token)
self.append_to_last_tokens(token_str)
# Check if the last n tokens match the answer_prefix_tokens list ...
if self.check_if_answer_reached():
@@ -95,4 +99,4 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
# If yes, then put tokens from now on
if self.answer_reached:
self.queue.put_nowait(token)
self.queue.put_nowait(token_str)

View File

@@ -76,10 +76,14 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.answer_reached = False
@override
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
def on_llm_new_token(
self, token: str | list[str | dict[str, Any]], **kwargs: Any
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
token_str = token if isinstance(token, str) else str(token)
# Remember the last n tokens, where n = len(answer_prefix_tokens)
self.append_to_last_tokens(token)
self.append_to_last_tokens(token_str)
# Check if the last n tokens match the answer_prefix_tokens list ...
if self.check_if_answer_reached():
@@ -92,5 +96,5 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
# ... if yes, then print tokens from now on
if self.answer_reached:
sys.stdout.write(token)
sys.stdout.write(token_str)
sys.stdout.flush()

View File

@@ -173,7 +173,7 @@ async def test_callback_handlers() -> None:
@override
async def on_llm_new_token(
self,
token: str,
token: str | list[str | dict[str, Any]],
*,
chunk: GenerationChunk | ChatGenerationChunk | None = None,
run_id: UUID,
@@ -181,7 +181,8 @@ async def test_callback_handlers() -> None:
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.store.append(token)
if isinstance(token, str):
self.store.append(token)
infinite_cycle = cycle(
[

View File

@@ -77,7 +77,7 @@ def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]:
return extract_tool_calls(message.content)
def extract_tool_calls(content: str | list[str | dict]) -> list[ToolCall]:
def extract_tool_calls(content: str | list[str | dict[str, Any]]) -> list[ToolCall]:
"""Extract tool calls from a list of content blocks."""
if isinstance(content, list):
tool_calls = []

View File

@@ -737,7 +737,7 @@ def test__format_messages_with_tool_calls() -> None:
assert expected == actual
# Check handling of empty AIMessage
empty_contents: list[str | list[str | dict]] = ["", []]
empty_contents: list[str | list[str | dict[str, Any]]] = ["", []]
for empty_content in empty_contents:
## Permit message in final position
_, anthropic_messages = _format_messages([human, AIMessage(empty_content)])

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
from typing import Any
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.messages import content as types
from langchain_core.messages.block_translators import register_translator
@@ -10,7 +12,7 @@ from langchain_core.messages.block_translators import register_translator
def _convert_from_v1_to_mistral(
content: list[types.ContentBlock],
model_provider: str | None,
) -> str | list[str | dict]:
) -> str | list[str | dict[str, Any]]:
new_content: list = []
for block in content:
if block["type"] == "text":

View File

@@ -339,8 +339,11 @@ async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
class MyCustomHandler(BaseCallbackHandler):
last_token: str = ""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.last_token = token
def on_llm_new_token(
self, token: str | list[str | dict[str, Any]], **kwargs: Any
) -> None:
if isinstance(token, str):
self.last_token = token
@patch(