diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 23419319040..6e4df3dd891 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -64,7 +64,7 @@ class LLMManagerMixin: def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -79,7 +79,7 @@ class LLMManagerMixin: For both chat models and non-chat models (legacy text completion LLMs). Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The new generated chunk, containing content and other information. run_id: The ID of the current run. parent_run_id: The ID of the parent run. @@ -631,7 +631,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -644,7 +644,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): For both chat models and non-chat models (legacy text completion LLMs). Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The new generated chunk, containing content and other information. run_id: The ID of the current run. parent_run_id: The ID of the parent run. diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 40fd907272a..8d46f74eea2 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -668,7 +668,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, **kwargs: Any, @@ -676,7 +676,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): """Run when LLM generates a new token. Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The chunk. **kwargs: Additional keyword arguments. @@ -787,7 +787,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, **kwargs: Any, @@ -795,7 +795,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """Run when LLM generates a new token. Args: - token: The new token. + token: The new token, or a list of content blocks. chunk: The chunk. **kwargs: Additional keyword arguments. diff --git a/libs/core/langchain_core/callbacks/streaming_stdout.py b/libs/core/langchain_core/callbacks/streaming_stdout.py index 920fef80bde..5cba233ee5a 100644 --- a/libs/core/langchain_core/callbacks/streaming_stdout.py +++ b/libs/core/langchain_core/callbacks/streaming_stdout.py @@ -47,14 +47,16 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): """ @override - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: """Run on new LLM token. Only available when streaming is enabled. Args: - token: The new token. + token: The new token, or a list of content blocks. **kwargs: Additional keyword arguments. """ - sys.stdout.write(token) + sys.stdout.write(str(token)) sys.stdout.flush() def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index cc946a919be..7c21d12110f 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -792,9 +792,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): index += 1 if "index" not in block: block["index"] = index - run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk - ) + run_manager.on_llm_new_token(chunk.message.content, chunk=chunk) chunks.append(chunk) yield cast("AIMessageChunk", chunk.message) yielded = True @@ -927,9 +925,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): index += 1 if "index" not in block: block["index"] = index - await run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk - ) + await run_manager.on_llm_new_token(chunk.message.content, chunk=chunk) chunks.append(chunk) yield cast("AIMessageChunk", chunk.message) yielded = True @@ -1972,9 +1968,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if run_manager: if chunk.message.id is None: chunk.message.id = run_id - run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk - ) + run_manager.on_llm_new_token(chunk.message.content, chunk=chunk) chunks.append(chunk) yielded = True @@ -2130,7 +2124,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): if chunk.message.id is None: chunk.message.id = run_id await run_manager.on_llm_new_token( - cast("str", chunk.message.content), chunk=chunk + chunk.message.content, chunk=chunk ) chunks.append(chunk) yielded = True diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 0c49530a30b..70ba436c8d2 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -149,7 +149,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): @override def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -161,7 +161,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): Only available when streaming is enabled. Args: - token: The token. + token: The token, or a list of content blocks for structured output. chunk: The chunk. run_id: The run ID. parent_run_id: The parent run ID. @@ -645,7 +645,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -919,7 +919,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): async def _on_llm_new_token( self, run: Run, - token: str, + token: str | list[str | dict[str, Any]], chunk: GenerationChunk | ChatGenerationChunk | None, ) -> None: """Process new LLM token.""" diff --git a/libs/core/langchain_core/tracers/core.py b/libs/core/langchain_core/tracers/core.py index c6f6cca167c..5e073259d98 100644 --- a/libs/core/langchain_core/tracers/core.py +++ b/libs/core/langchain_core/tracers/core.py @@ -242,7 +242,7 @@ class _TracerCore(ABC): def _llm_run_with_token_event( self, - token: str, + token: str | list[str | dict[str, Any]], run_id: UUID, chunk: GenerationChunk | ChatGenerationChunk | None = None, parent_run_id: UUID | None = None, @@ -602,14 +602,14 @@ class _TracerCore(ABC): def _on_llm_new_token( self, run: Run, - token: str, + token: str | list[str | dict[str, Any]], chunk: GenerationChunk | ChatGenerationChunk | None, ) -> Coroutine[Any, Any, None] | None: """Process new LLM token. Args: run: The LLM run. - token: The new token. + token: The new token, or a list of content blocks. chunk: Optional chunk. """ _ = (run, token, chunk) diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index cba55b8bb9c..7d5db9cc8f6 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -429,7 +429,7 @@ class _AstreamEventsCallbackHandler( @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -465,7 +465,8 @@ class _AstreamEventsCallbackHandler( elif run_info["run_type"] == "llm": event = "on_llm_stream" if chunk is None: - chunk_ = GenerationChunk(text=token) + text = token if isinstance(token, str) else "" + chunk_ = GenerationChunk(text=text) else: chunk_ = cast("GenerationChunk", chunk) else: diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 6295a2f034b..9b3bbd37527 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -359,7 +359,7 @@ class LangChainTracer(BaseTracer): @override def _llm_run_with_token_event( self, - token: str, + token: str | list[str | dict[str, Any]], run_id: UUID, chunk: GenerationChunk | ChatGenerationChunk | None = None, parent_run_id: UUID | None = None, diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index ebafc14b0e9..cf94189c9d5 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -537,7 +537,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler[Any]): def _on_llm_new_token( self, run: Run, - token: str, + token: str | list[str | dict[str, Any]], chunk: GenerationChunk | ChatGenerationChunk | None, ) -> None: """Process new LLM token.""" diff --git a/libs/core/tests/benchmarks/test_async_callbacks.py b/libs/core/tests/benchmarks/test_async_callbacks.py index f6774ae3620..8f5582a96a7 100644 --- a/libs/core/tests/benchmarks/test_async_callbacks.py +++ b/libs/core/tests/benchmarks/test_async_callbacks.py @@ -33,7 +33,7 @@ class MyCustomAsyncHandler(AsyncCallbackHandler): @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index bf5629a12c5..786c7bcfea3 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -182,7 +182,7 @@ async def test_callback_handlers() -> None: @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -190,7 +190,8 @@ async def test_callback_handlers() -> None: tags: list[str] | None = None, **kwargs: Any, ) -> None: - self.store.append(token) + if isinstance(token, str): + self.store.append(token) infinite_cycle = cycle( [ diff --git a/libs/langchain/Makefile b/libs/langchain/Makefile index 421d4a5b755..92f0dcaa59e 100644 --- a/libs/langchain/Makefile +++ b/libs/langchain/Makefile @@ -52,19 +52,22 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain - lint_package: PYTHON_FILES=langchain_classic lint_tests: PYTHON_FILES=tests lint_tests: MYPY_CACHE=.mypy_cache_test +UV_RUN_LINT = uv run --all-groups +UV_RUN_TYPE = uv run --all-groups +lint_package lint_tests: UV_RUN_LINT = uv run --group lint lint lint_diff lint_package lint_tests: ./scripts/lint_imports.sh - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES) --diff - [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES) --diff + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) type: - mkdir -p $(MYPY_CACHE) && uv run --group lint --group typing mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + mkdir -p $(MYPY_CACHE) && $(UV_RUN_TYPE) mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) format format_diff: - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff format $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || uv run --group lint --group typing ruff check --fix $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff format $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || $(UV_RUN_LINT) ruff check --fix $(PYTHON_FILES) ###################### # HELP diff --git a/libs/langchain/langchain_classic/callbacks/streaming_aiter.py b/libs/langchain/langchain_classic/callbacks/streaming_aiter.py index 0811e86a3e7..cb0ab1f86dc 100644 --- a/libs/langchain/langchain_classic/callbacks/streaming_aiter.py +++ b/libs/langchain/langchain_classic/callbacks/streaming_aiter.py @@ -39,9 +39,12 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler): self.done.clear() @override - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - if token is not None and token != "": - self.queue.put_nowait(token) + async def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: + token_str = token if isinstance(token, str) else str(token) + if token_str != "": + self.queue.put_nowait(token_str) @override async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: diff --git a/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py b/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py index 744475e6628..37b3ce93a9a 100644 --- a/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py +++ b/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py @@ -81,9 +81,13 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.done.set() @override - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + async def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: + token_str = token if isinstance(token, str) else str(token) + # Remember the last n tokens, where n = len(answer_prefix_tokens) - self.append_to_last_tokens(token) + self.append_to_last_tokens(token_str) # Check if the last n tokens match the answer_prefix_tokens list ... if self.check_if_answer_reached(): @@ -95,4 +99,4 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler): # If yes, then put tokens from now on if self.answer_reached: - self.queue.put_nowait(token) + self.queue.put_nowait(token_str) diff --git a/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py b/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py index e8eee519b3e..62ef56ddbdd 100644 --- a/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py +++ b/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py @@ -76,10 +76,14 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): self.answer_reached = False @override - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" + token_str = token if isinstance(token, str) else str(token) + # Remember the last n tokens, where n = len(answer_prefix_tokens) - self.append_to_last_tokens(token) + self.append_to_last_tokens(token_str) # Check if the last n tokens match the answer_prefix_tokens list ... if self.check_if_answer_reached(): @@ -92,5 +96,5 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): # ... if yes, then print tokens from now on if self.answer_reached: - sys.stdout.write(token) + sys.stdout.write(token_str) sys.stdout.flush() diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index 88cd91c3d13..57a2f884985 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -173,7 +173,7 @@ async def test_callback_handlers() -> None: @override async def on_llm_new_token( self, - token: str, + token: str | list[str | dict[str, Any]], *, chunk: GenerationChunk | ChatGenerationChunk | None = None, run_id: UUID, @@ -181,7 +181,8 @@ async def test_callback_handlers() -> None: tags: list[str] | None = None, **kwargs: Any, ) -> None: - self.store.append(token) + if isinstance(token, str): + self.store.append(token) infinite_cycle = cycle( [ diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index 57e2121c5e5..60b189de1d6 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -77,7 +77,7 @@ def _extract_tool_calls_from_message(message: AIMessage) -> list[ToolCall]: return extract_tool_calls(message.content) -def extract_tool_calls(content: str | list[str | dict]) -> list[ToolCall]: +def extract_tool_calls(content: str | list[str | dict[str, Any]]) -> list[ToolCall]: """Extract tool calls from a list of content blocks.""" if isinstance(content, list): tool_calls = [] diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index fe1f217f851..0dbdceb483e 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -737,7 +737,7 @@ def test__format_messages_with_tool_calls() -> None: assert expected == actual # Check handling of empty AIMessage - empty_contents: list[str | list[str | dict]] = ["", []] + empty_contents: list[str | list[str | dict[str, Any]]] = ["", []] for empty_content in empty_contents: ## Permit message in final position _, anthropic_messages = _format_messages([human, AIMessage(empty_content)]) diff --git a/libs/partners/mistralai/langchain_mistralai/_compat.py b/libs/partners/mistralai/langchain_mistralai/_compat.py index f716300dccc..16f525791f7 100644 --- a/libs/partners/mistralai/langchain_mistralai/_compat.py +++ b/libs/partners/mistralai/langchain_mistralai/_compat.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + from langchain_core.messages import AIMessage, AIMessageChunk from langchain_core.messages import content as types from langchain_core.messages.block_translators import register_translator @@ -10,7 +12,7 @@ from langchain_core.messages.block_translators import register_translator def _convert_from_v1_to_mistral( content: list[types.ContentBlock], model_provider: str | None, -) -> str | list[str | dict]: +) -> str | list[str | dict[str, Any]]: new_content: list = [] for block in content: if block["type"] == "text": diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 90be8938812..38de8cdb074 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -339,8 +339,11 @@ async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator: class MyCustomHandler(BaseCallbackHandler): last_token: str = "" - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - self.last_token = token + def on_llm_new_token( + self, token: str | list[str | dict[str, Any]], **kwargs: Any + ) -> None: + if isinstance(token, str): + self.last_token = token @patch(