From f89f4c5afe5ef1044195469d6647c7c51869296e Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Wed, 10 Jun 2026 16:59:08 -0400 Subject: [PATCH] fix(core): support content block tokens in callbacks (#34739) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Supersedes #34727 Closes #30703 Related: * langchain-ai/langchain-google#1460 * langchain-ai/langchain-google#1501 Fixing this at the `langchain-core` callback layer instead of normalizing inside individual provider integrations, so structured streaming content is preserved consistently. --- Models are increasingly streaming structured content blocks instead of plain text tokens. For example, Gemini 3 can stream text as content-block lists, and Anthropic/tool-use flows can also produce non-text message content. Today those values already reach `on_llm_new_token`, but the callback API still advertises `token: str`, which makes custom callbacks, tracers, and streaming helpers assume every streamed value is text. User story: as a LangChain user building a streaming callback for chat models with tool calls, reasoning/thinking blocks, or provider-specific structured content, I need `on_llm_new_token` to accept the same content shape that chat model chunks can actually emit, so my callback can observe the stream without providers flattening or dropping non-text data. Fixing this in `langchain-core` makes the existing runtime behavior explicit at the shared callback boundary. Normalizing content blocks inside each provider would duplicate logic, produce inconsistent behavior across integrations, and in some cases lose required provider metadata such as Gemini thought signatures. 
## Changes

- Update the callback contract so streamed tokens can be either plain text or structured content blocks
- Carry structured streamed content through tracing and event/log streaming paths without forcing provider data into text too early
- Keep built-in text-oriented streaming callbacks working by converting structured tokens only at the display/queue boundary
- Drop the now-incorrect `cast("str", ...)` on streamed content in `BaseChatModel` so the producer side matches the widened callback signature instead of asserting a string it doesn't always have (no runtime change — `cast` is erased)
- Align Anthropic and Mistral content typing with the structured content shapes already used by chat model messages
- Update callback tests to reflect that not every streamed value is text

## Compatibility

No runtime behavior change: no producer emits anything it wasn't already emitting, and widening a parameter type is safe for existing callers and handlers that pass or receive `str`. The one caveat is downstream code that subclasses a callback handler or tracer and overrides `on_llm_new_token` with a `token: str` annotation — under strict type checking that override is now narrower than the base and will be flagged as incompatible with the supertype. Such code still runs unchanged; the fix is to widen the annotation to match.
--- libs/core/langchain_core/callbacks/base.py | 8 ++++---- libs/core/langchain_core/callbacks/manager.py | 8 ++++---- .../langchain_core/callbacks/streaming_stdout.py | 8 +++++--- .../langchain_core/language_models/chat_models.py | 14 ++++---------- libs/core/langchain_core/tracers/base.py | 8 ++++---- libs/core/langchain_core/tracers/core.py | 6 +++--- libs/core/langchain_core/tracers/event_stream.py | 5 +++-- libs/core/langchain_core/tracers/langchain.py | 2 +- libs/core/langchain_core/tracers/log_stream.py | 2 +- .../core/tests/benchmarks/test_async_callbacks.py | 2 +- .../tests/unit_tests/fake/test_fake_chat_model.py | 5 +++-- libs/langchain/Makefile | 15 +++++++++------ .../callbacks/streaming_aiter.py | 9 ++++++--- .../callbacks/streaming_aiter_final_only.py | 10 +++++++--- .../callbacks/streaming_stdout_final_only.py | 10 +++++++--- .../tests/unit_tests/llms/test_fake_chat_model.py | 5 +++-- .../langchain_anthropic/output_parsers.py | 2 +- .../tests/unit_tests/test_chat_models.py | 2 +- .../mistralai/langchain_mistralai/_compat.py | 4 +++- .../tests/unit_tests/test_chat_models.py | 7 +++++-- 20 files changed, 75 insertions(+), 57 deletions(-) diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 23419319040..6e4df3dd891 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -64,7 +64,7 @@ class LLMManagerMixin: def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -79,7 +79,7 @@ class LLMManagerMixin: For both chat models and non-chat models (legacy text completion LLMs). Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The new generated chunk, containing content and other information. run_id: The ID of the current run. parent_run_id: The ID of the parent run. 
@@ -631,7 +631,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -644,7 +644,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): For both chat models and non-chat models (legacy text completion LLMs). Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The new generated chunk, containing content and other information. run_id: The ID of the current run. parent_run_id: The ID of the parent run. diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 40fd907272a..8d46f74eea2 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -668,7 +668,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, **kwargs: Any, @@ -676,7 +676,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): """Run when LLM generates a new token. Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The chunk. **kwargs: Additional keyword arguments. @@ -787,7 +787,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, **kwargs: Any, @@ -795,7 +795,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """Run when LLM generates a new token. Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The chunk. **kwargs: Additional keyword arguments. 
diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index 920fef80bde..5cba233ee5a 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -47,14 +47,16 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): """ @override - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: """Run on new LLM token. Only available when streaming is enabled. Args: - token: The new token. + token: The new token, or a list of content blocks. **kwargs: Additional keyword arguments. """ - sys.stdout.write(token) + sys.stdout.write(str(token)) sys.stdout.flush() def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index cc946a919be..7c21d12110f 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -792,9 +792,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): index += 1 if "index" not in block: block["index"] = index - run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk - ) + run_manager.on_llm_new_token(chunk.message.content, chunk=chunk) chunks.append(chunk) yield cast("AIMessageChunk", chunk.message) yielded = True @@ -927,9 +925,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): index += 1 if "index" not in block: block["index"] = index - await run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk - ) + await run_manager.on_llm_new_token(chunk.message.content, chunk=chunk) chunks.append(chunk) yield cast("AIMessageChunk", chunk.message) yielded = True @@ -1972,9 +1968,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if 
run_manager: if chunk.message.id is None: chunk.message.id = run_id - run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk - ) + run_manager.on_llm_new_token(chunk.message.content, chunk=chunk) chunks.append(chunk) yielded = True @@ -2130,7 +2124,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if chunk.message.id is None: chunk.message.id = run_id await run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk + chunk.message.content, chunk=chunk ) chunks.append(chunk) yielded = True diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 0c49530a30b..70ba436c8d2 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -149,7 +149,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): @override def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -161,7 +161,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): Only available when streaming is enabled. Args: - token: The token. + token: The token, or a list of content blocks for structured output. chunk: The chunk. run_id: The run ID. parent_run_id: The parent run ID. 
@@ -645,7 +645,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -919,7 +919,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def _on_llm_new_token( self, run: Run, - token: str, + token: str | list[str | dict[str, Any]], chunk: GenerationChunk | ChatGenerationChunk | None, ) -> None: """Process new LLM token.""" diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index c6f6cca167c..5e073259d98 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -242,7 +242,7 @@ class _TracerCore(ABC): def _llm_run_with_token_event( self, - token: str, + token: str | list[str | dict[str, Any]], run_id: UUID, chunk: GenerationChunk | ChatGenerationChunk | None = None, parent_run_id: UUID | None = None, @@ -602,14 +602,14 @@ class _TracerCore(ABC): def _on_llm_new_token( self, run: Run, - token: str, + token: str | list[str | dict[str, Any]], chunk: GenerationChunk | ChatGenerationChunk | None, ) -> Coroutine[Any, Any, None] | None: """Process new LLM token. Args: run: The LLM run. - token: The new token. + token: The new token, or a list of content blocks. chunk: Optional chunk. 
""" _ = (run, token, chunk) diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index cba55b8bb9c..7d5db9cc8f6 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -429,7 +429,7 @@ class _AstreamEventsCallbackHandler( @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -465,7 +465,8 @@ class _AstreamEventsCallbackHandler( elif run_info["run_type"] == "llm": event = "on_llm_stream" if chunk is None: - chunk_ = GenerationChunk(text=token) + text = token if isinstance(token, str) else "" + chunk_ = GenerationChunk(text=text) else: chunk_ = cast("GenerationChunk", chunk) else: diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 6295a2f034b..9b3bbd37527 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -359,7 +359,7 @@ class LangChainTracer(BaseTracer): @override def _llm_run_with_token_event( self, - token: str, + token: str | list[str | dict[str, Any]], run_id: UUID, chunk: GenerationChunk | ChatGenerationChunk | None = None, parent_run_id: UUID | None = None, diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index ebafc14b0e9..cf94189c9d5 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -537,7 +537,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler[Any]): def _on_llm_new_token( self, run: Run, - token: str, + token: str | list[str | dict[str, Any]], chunk: GenerationChunk | ChatGenerationChunk | None, ) -> None: """Process new LLM token.""" diff --git a/libs/core/tests/benchmarks/test_async_callbacks.py 
b/libs/core/tests/benchmarks/test_async_callbacks.py index f6774ae3620..8f5582a96a7 100644 --- a/libs/core/tests/benchmarks/test_async_callbacks.py +++ b/libs/core/tests/benchmarks/test_async_callbacks.py @@ -33,7 +33,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler): @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index bf5629a12c5..786c7bcfea3 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -182,7 +182,7 @@ async def test_callback_handlers() -> None: @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -190,7 +190,8 @@ async def test_callback_handlers() -> None: tags: list[str] | None = None, **kwargs: Any, ) -> None: - self.store.append(token) + if isinstance(token, str): + self.store.append(token) infinite_cycle = cycle( [ diff --git a/libs/langchain/Makefile b/libs/langchain/Makefile index 421d4a5b755..92f0dcaa59e 100644 --- a/libs/langchain/Makefile +++ b/libs/langchain/Makefile @@ -52,19 +52,22 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain - lint_package: PYTHON_FILES=langchain_classic lint_tests: PYTHON_FILES=tests lint_tests: MYPY_CACHE=.mypy_cache_test +UV_RUN_LINT = uv run --all-groups +UV_RUN_TYPE = uv run --all-groups +lint_package lint_tests: UV_RUN_LINT = uv run --group lint lint lint_diff lint_package lint_tests: ./scripts/lint_imports.sh - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES) --diff - [ 
"$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES) --diff + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) type: - mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) format format_diff: - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check --fix $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check --fix $(PYTHON_FILES) ###################### # HELP diff --git a/libs/langchain/langchain_classic/callbacks/streaming_aiter.py b/libs/langchain/langchain_classic/callbacks/streaming_aiter.py index 0811e86a3e7..cb0ab1f86dc 100644 --- a/libs/langchain/langchain_classic/callbacks/streaming_aiter.py +++ b/libs/langchain/langchain_classic/callbacks/streaming_aiter.py @@ -39,9 +39,12 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler): self.done.clear() @override - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - if token is not None and token != "": - self.queue.put_nowait(token) + async def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: + token_str = token if isinstance(token, str) else str(token) + if token_str != "": + self.queue.put_nowait(token_str) @override async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: diff --git a/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py 
b/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py index 744475e6628..37b3ce93a9a 100644 --- a/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py +++ b/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py @@ -81,9 +81,13 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.done.set() @override - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + async def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: + token_str = token if isinstance(token, str) else str(token) + # Remember the last n tokens, where n = len(answer_prefix_tokens) - self.append_to_last_tokens(token) + self.append_to_last_tokens(token_str) # Check if the last n tokens match the answer_prefix_tokens list ... if self.check_if_answer_reached(): @@ -95,4 +99,4 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): # If yes, then put tokens from now on if self.answer_reached: - self.queue.put_nowait(token) + self.queue.put_nowait(token_str) diff --git a/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py b/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py index e8eee519b3e..62ef56ddbdd 100644 --- a/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py +++ b/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py @@ -76,10 +76,14 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): self.answer_reached = False @override - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: """Run on new LLM token. 
Only available when streaming is enabled.""" + token_str = token if isinstance(token, str) else str(token) + # Remember the last n tokens, where n = len(answer_prefix_tokens) - self.append_to_last_tokens(token) + self.append_to_last_tokens(token_str) # Check if the last n tokens match the answer_prefix_tokens list ... if self.check_if_answer_reached(): @@ -92,5 +96,5 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): # ... if yes, then print tokens from now on if self.answer_reached: - sys.stdout.write(token) + sys.stdout.write(token_str) sys.stdout.flush() diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index 88cd91c3d13..57a2f884985 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -173,7 +173,7 @@ async def test_callback_handlers() -> None: @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -181,7 +181,8 @@ async def test_callback_handlers() -> None: tags: list[str] | None = None, **kwargs: Any, ) -> None: - self.store.append(token) + if isinstance(token, str): + self.store.append(token) infinite_cycle = cycle( [ diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index 57e2121c5e5..60b189de1d6 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -77,7 +77,7 @@ def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]: return extract_tool_calls(message.content) -def extract_tool_calls(content: str | list[str | dict]) -> list[ToolCall]: +def extract_tool_calls(content: str | list[str | dict[str, Any]]) -> list[ToolCall]: """Extract tool 
calls from a list of content blocks.""" if isinstance(content, list): tool_calls = [] diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index fe1f217f851..0dbdceb483e 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -737,7 +737,7 @@ def test__format_messages_with_tool_calls() -> None: assert expected == actual # Check handling of empty AIMessage - empty_contents: list[str | list[str | dict]] = ["", []] + empty_contents: list[str | list[str | dict[str, Any]]] = ["", []] for empty_content in empty_contents: ## Permit message in final position _, anthropic_messages = _format_messages([human, AIMessage(empty_content)]) diff --git a/libs/partners/mistralai/langchain_mistralai/_compat.py b/libs/partners/mistralai/langchain_mistralai/_compat.py index f716300dccc..16f525791f7 100644 --- a/libs/partners/mistralai/langchain_mistralai/_compat.py +++ b/libs/partners/mistralai/langchain_mistralai/_compat.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + from langchain_core.messages import AIMessage, AIMessageChunk from langchain_core.messages import content as types from langchain_core.messages.block_translators import register_translator @@ -10,7 +12,7 @@ from langchain_core.messages.block_translators import register_translator def _convert_from_v1_to_mistral( content: list[types.ContentBlock], model_provider: str | None, -) -> str | list[str | dict]: +) -> str | list[str | dict[str, Any]]: new_content: list = [] for block in content: if block["type"] == "text": diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 90be8938812..38de8cdb074 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ 
-339,8 +339,11 @@ async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator: class MyCustomHandler(BaseCallbackHandler): last_token: str = "" - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - self.last_token = token + def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: + if isinstance(token, str): + self.last_token = token @patch(