mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 06:42:37 +00:00
fix(core): support content block tokens in callbacks (#34739)
Supersedes #34727 Closes #30703 Related: * langchain-ai/langchain-google#1460 * langchain-ai/langchain-google#1501 Fixing this at the `langchain-core` callback layer instead of normalizing inside individual provider integrations, so structured streaming content is preserved consistently. --- Models are increasingly streaming structured content blocks instead of plain text tokens. For example, Gemini 3 can stream text as content-block lists, and Anthropic/tool-use flows can also produce non-text message content. Today those values already reach `on_llm_new_token`, but the callback API still advertises `token: str`, which makes custom callbacks, tracers, and streaming helpers assume every streamed value is text. User story: as a LangChain user building a streaming callback for chat models with tool calls, reasoning/thinking blocks, or provider-specific structured content, I need `on_llm_new_token` to accept the same content shape that chat model chunks can actually emit, so my callback can observe the stream without providers flattening or dropping non-text data. Fixing this in `langchain-core` makes the existing runtime behavior explicit at the shared callback boundary. Normalizing content blocks inside each provider would duplicate logic, produce inconsistent behavior across integrations, and in some cases lose required provider metadata such as Gemini thought signatures. ## Changes - Update the callback contract so streamed tokens can be either plain text or structured content blocks - Carry structured streamed content through tracing and event/log streaming paths without forcing provider data into text too early - Keep built-in text-oriented streaming callbacks working by converting structured tokens only at the display/queue boundary - Drop the now-incorrect `cast("str", ...)` on streamed content in `BaseChatModel` so the producer side matches the widened callback signature instead of asserting a string it doesn't always have (no runtime change — `cast` is erased) - Align Anthropic and Mistral content typing with the structured content shapes already used by chat model messages - Update callback tests to reflect that not every streamed value is text ## Compatibility No runtime behavior change: no producer emits anything it wasn't already emitting, and widening a parameter type is safe for existing callers and handlers that pass or receive `str`. The one caveat is downstream code that subclasses a callback handler or tracer and overrides `on_llm_new_token` with a `token: str` annotation — under strict type checking that override is now narrower than the base and will be flagged as incompatible with the supertype. Such code still runs unchanged; the fix is to widen the annotation to match.
This commit is contained in:
@@ -64,7 +64,7 @@ class LLMManagerMixin:
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -79,7 +79,7 @@ class LLMManagerMixin:
|
||||
For both chat models and non-chat models (legacy text completion LLMs).
|
||||
|
||||
Args:
|
||||
token: The new token.
|
||||
token: The new token, or a list of content blocks.
|
||||
chunk: The new generated chunk, containing content and other information.
|
||||
run_id: The ID of the current run.
|
||||
parent_run_id: The ID of the parent run.
|
||||
@@ -631,7 +631,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -644,7 +644,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
For both chat models and non-chat models (legacy text completion LLMs).
|
||||
|
||||
Args:
|
||||
token: The new token.
|
||||
token: The new token, or a list of content blocks.
|
||||
chunk: The new generated chunk, containing content and other information.
|
||||
run_id: The ID of the current run.
|
||||
parent_run_id: The ID of the parent run.
|
||||
|
||||
@@ -668,7 +668,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -676,7 +676,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
"""Run when LLM generates a new token.
|
||||
|
||||
Args:
|
||||
token: The new token.
|
||||
token: The new token, or a list of content blocks.
|
||||
chunk: The chunk.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
@@ -787,7 +787,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -795,7 +795,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
"""Run when LLM generates a new token.
|
||||
|
||||
Args:
|
||||
token: The new token.
|
||||
token: The new token, or a list of content blocks.
|
||||
chunk: The chunk.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
|
||||
@@ -47,14 +47,16 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
|
||||
@override
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled.
|
||||
|
||||
Args:
|
||||
token: The new token.
|
||||
token: The new token, or a list of content blocks.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.write(str(token))
|
||||
sys.stdout.flush()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
|
||||
@@ -792,9 +792,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
index += 1
|
||||
if "index" not in block:
|
||||
block["index"] = index
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
yield cast("AIMessageChunk", chunk.message)
|
||||
yielded = True
|
||||
@@ -927,9 +925,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
index += 1
|
||||
if "index" not in block:
|
||||
block["index"] = index
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
await run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
yield cast("AIMessageChunk", chunk.message)
|
||||
yielded = True
|
||||
@@ -1972,9 +1968,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
if run_manager:
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = run_id
|
||||
run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
)
|
||||
run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
|
||||
chunks.append(chunk)
|
||||
yielded = True
|
||||
|
||||
@@ -2130,7 +2124,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
if chunk.message.id is None:
|
||||
chunk.message.id = run_id
|
||||
await run_manager.on_llm_new_token(
|
||||
cast("str", chunk.message.content), chunk=chunk
|
||||
chunk.message.content, chunk=chunk
|
||||
)
|
||||
chunks.append(chunk)
|
||||
yielded = True
|
||||
|
||||
@@ -149,7 +149,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
@override
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -161,7 +161,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
Only available when streaming is enabled.
|
||||
|
||||
Args:
|
||||
token: The token.
|
||||
token: The token, or a list of content blocks for structured output.
|
||||
chunk: The chunk.
|
||||
run_id: The run ID.
|
||||
parent_run_id: The parent run ID.
|
||||
@@ -645,7 +645,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -919,7 +919,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
async def _on_llm_new_token(
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None,
|
||||
) -> None:
|
||||
"""Process new LLM token."""
|
||||
|
||||
@@ -242,7 +242,7 @@ class _TracerCore(ABC):
|
||||
|
||||
def _llm_run_with_token_event(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
run_id: UUID,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
parent_run_id: UUID | None = None,
|
||||
@@ -602,14 +602,14 @@ class _TracerCore(ABC):
|
||||
def _on_llm_new_token(
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None,
|
||||
) -> Coroutine[Any, Any, None] | None:
|
||||
"""Process new LLM token.
|
||||
|
||||
Args:
|
||||
run: The LLM run.
|
||||
token: The new token.
|
||||
token: The new token, or a list of content blocks.
|
||||
chunk: Optional chunk.
|
||||
"""
|
||||
_ = (run, token, chunk)
|
||||
|
||||
@@ -429,7 +429,7 @@ class _AstreamEventsCallbackHandler(
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -465,7 +465,8 @@ class _AstreamEventsCallbackHandler(
|
||||
elif run_info["run_type"] == "llm":
|
||||
event = "on_llm_stream"
|
||||
if chunk is None:
|
||||
chunk_ = GenerationChunk(text=token)
|
||||
text = token if isinstance(token, str) else ""
|
||||
chunk_ = GenerationChunk(text=text)
|
||||
else:
|
||||
chunk_ = cast("GenerationChunk", chunk)
|
||||
else:
|
||||
|
||||
@@ -359,7 +359,7 @@ class LangChainTracer(BaseTracer):
|
||||
@override
|
||||
def _llm_run_with_token_event(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
run_id: UUID,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
parent_run_id: UUID | None = None,
|
||||
|
||||
@@ -537,7 +537,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler[Any]):
|
||||
def _on_llm_new_token(
|
||||
self,
|
||||
run: Run,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None,
|
||||
) -> None:
|
||||
"""Process new LLM token."""
|
||||
|
||||
@@ -33,7 +33,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
|
||||
@@ -182,7 +182,7 @@ async def test_callback_handlers() -> None:
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -190,7 +190,8 @@ async def test_callback_handlers() -> None:
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.store.append(token)
|
||||
if isinstance(token, str):
|
||||
self.store.append(token)
|
||||
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
|
||||
@@ -52,19 +52,22 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain -
|
||||
lint_package: PYTHON_FILES=langchain_classic
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
UV_RUN_LINT = uv run --all-groups
|
||||
UV_RUN_TYPE = uv run --all-groups
|
||||
lint_package lint_tests: UV_RUN_LINT = uv run --group lint
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
./scripts/lint_imports.sh
|
||||
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES) --diff
|
||||
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES) --diff
|
||||
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
type:
|
||||
mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check --fix $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check --fix $(PYTHON_FILES)
|
||||
|
||||
######################
|
||||
# HELP
|
||||
|
||||
@@ -39,9 +39,12 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||||
self.done.clear()
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
if token is not None and token != "":
|
||||
self.queue.put_nowait(token)
|
||||
async def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
token_str = token if isinstance(token, str) else str(token)
|
||||
if token_str != "":
|
||||
self.queue.put_nowait(token_str)
|
||||
|
||||
@override
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
|
||||
@@ -81,9 +81,13 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
self.done.set()
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
async def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
token_str = token if isinstance(token, str) else str(token)
|
||||
|
||||
# Remember the last n tokens, where n = len(answer_prefix_tokens)
|
||||
self.append_to_last_tokens(token)
|
||||
self.append_to_last_tokens(token_str)
|
||||
|
||||
# Check if the last n tokens match the answer_prefix_tokens list ...
|
||||
if self.check_if_answer_reached():
|
||||
@@ -95,4 +99,4 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
|
||||
# If yes, then put tokens from now on
|
||||
if self.answer_reached:
|
||||
self.queue.put_nowait(token)
|
||||
self.queue.put_nowait(token_str)
|
||||
|
||||
@@ -76,10 +76,14 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
self.answer_reached = False
|
||||
|
||||
@override
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
token_str = token if isinstance(token, str) else str(token)
|
||||
|
||||
# Remember the last n tokens, where n = len(answer_prefix_tokens)
|
||||
self.append_to_last_tokens(token)
|
||||
self.append_to_last_tokens(token_str)
|
||||
|
||||
# Check if the last n tokens match the answer_prefix_tokens list ...
|
||||
if self.check_if_answer_reached():
|
||||
@@ -92,5 +96,5 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
|
||||
# ... if yes, then print tokens from now on
|
||||
if self.answer_reached:
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.write(token_str)
|
||||
sys.stdout.flush()
|
||||
|
||||
@@ -173,7 +173,7 @@ async def test_callback_handlers() -> None:
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -181,7 +181,8 @@ async def test_callback_handlers() -> None:
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.store.append(token)
|
||||
if isinstance(token, str):
|
||||
self.store.append(token)
|
||||
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
|
||||
@@ -77,7 +77,7 @@ def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]:
|
||||
return extract_tool_calls(message.content)
|
||||
|
||||
|
||||
def extract_tool_calls(content: str | list[str | dict]) -> list[ToolCall]:
|
||||
def extract_tool_calls(content: str | list[str | dict[str, Any]]) -> list[ToolCall]:
|
||||
"""Extract tool calls from a list of content blocks."""
|
||||
if isinstance(content, list):
|
||||
tool_calls = []
|
||||
|
||||
@@ -737,7 +737,7 @@ def test__format_messages_with_tool_calls() -> None:
|
||||
assert expected == actual
|
||||
|
||||
# Check handling of empty AIMessage
|
||||
empty_contents: list[str | list[str | dict]] = ["", []]
|
||||
empty_contents: list[str | list[str | dict[str, Any]]] = ["", []]
|
||||
for empty_content in empty_contents:
|
||||
## Permit message in final position
|
||||
_, anthropic_messages = _format_messages([human, AIMessage(empty_content)])
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages import content as types
|
||||
from langchain_core.messages.block_translators import register_translator
|
||||
@@ -10,7 +12,7 @@ from langchain_core.messages.block_translators import register_translator
|
||||
def _convert_from_v1_to_mistral(
|
||||
content: list[types.ContentBlock],
|
||||
model_provider: str | None,
|
||||
) -> str | list[str | dict]:
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
new_content: list = []
|
||||
for block in content:
|
||||
if block["type"] == "text":
|
||||
|
||||
@@ -339,8 +339,11 @@ async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
|
||||
class MyCustomHandler(BaseCallbackHandler):
|
||||
last_token: str = ""
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
self.last_token = token
|
||||
def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
if isinstance(token, str):
|
||||
self.last_token = token
|
||||
|
||||
|
||||
@patch(
|
||||
|
||||
Reference in New Issue
Block a user