diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 220df5c77b7..6f03a3f6709 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -285,9 +285,29 @@ class CallbackManagerMixin: This method is called for chat models. If you're implementing a handler for a non-chat model, you should use `on_llm_start` instead. + !!! note + + When overriding this method, the signature **must** include the two + required positional arguments ``serialized`` and ``messages``. Avoid + using ``*args`` in your override — doing so causes an ``IndexError`` + in the fallback path when the callback system converts ``messages`` + to prompt strings for ``on_llm_start``. Always declare the + signature explicitly: + + .. code-block:: python + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + **kwargs: Any, + ) -> None: + raise NotImplementedError # triggers fallback to on_llm_start + Args: serialized: The serialized chat model. - messages: The messages. + messages: The messages. Must be a list of message lists — this is a + required positional argument and must be present in any override. run_id: The ID of the current run. parent_run_id: The ID of the parent run. tags: The tags. @@ -295,7 +315,7 @@ class CallbackManagerMixin: **kwargs: Additional keyword arguments. """ # NotImplementedError is thrown intentionally - # Callback handler will fall back to on_llm_start if this is exception is thrown + # Callback handler will fall back to on_llm_start if this exception is thrown msg = f"{self.__class__.__name__} does not implement `on_chat_model_start`" raise NotImplementedError(msg) @@ -534,9 +554,29 @@ class AsyncCallbackHandler(BaseCallbackHandler): This method is called for chat models. If you're implementing a handler for a non-chat model, you should use `on_llm_start` instead. + !!! note + + When overriding this method, the signature **must** include the two + required positional arguments ``serialized`` and ``messages``. Avoid + using ``*args`` in your override — doing so causes an ``IndexError`` + in the fallback path when the callback system converts ``messages`` + to prompt strings for ``on_llm_start``. Always declare the + signature explicitly: + + .. code-block:: python + + async def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + **kwargs: Any, + ) -> None: + raise NotImplementedError # triggers fallback to on_llm_start + Args: serialized: The serialized chat model. - messages: The messages. + messages: The messages. Must be a list of message lists — this is a + required positional argument and must be present in any override. run_id: The ID of the current run. parent_run_id: The ID of the parent run. tags: The tags. @@ -544,7 +584,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): **kwargs: Additional keyword arguments. """ # NotImplementedError is thrown intentionally - # Callback handler will fall back to on_llm_start if this is exception is thrown + # Callback handler will fall back to on_llm_start if this exception is thrown msg = f"{self.__class__.__name__} does not implement `on_chat_model_start`" raise NotImplementedError(msg) diff --git a/libs/core/tests/unit_tests/callbacks/test_handle_event.py b/libs/core/tests/unit_tests/callbacks/test_handle_event.py new file mode 100644 index 00000000000..b824717a9dc --- /dev/null +++ b/libs/core/tests/unit_tests/callbacks/test_handle_event.py @@ -0,0 +1,134 @@ +"""Tests for handle_event and _ahandle_event_for_handler fallback behavior. + +Covers the NotImplementedError fallback from on_chat_model_start to on_llm_start. +Handlers must declare `serialized` and `messages` as explicit positional args +(not *args) — see on_chat_model_start docstring for details. + +See: https://github.com/langchain-ai/langchain/issues/31576 +""" + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.callbacks.manager import ( + _ahandle_event_for_handler, + handle_event, +) +from langchain_core.messages import BaseMessage, HumanMessage + + +class _FallbackChatHandler(BaseCallbackHandler): + """Handler that correctly declares the required args but raises NotImplementedError. + + This triggers the fallback to on_llm_start, as documented. + """ + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + **kwargs: Any, + ) -> None: + raise NotImplementedError + + def on_llm_start(self, *args: Any, **kwargs: Any) -> None: + pass + + +class _FallbackChatHandlerAsync(BaseCallbackHandler): + """Async-compatible handler; raises NotImplementedError for on_chat_model_start.""" + + run_inline = True + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + **kwargs: Any, + ) -> None: + raise NotImplementedError + + def on_llm_start(self, *args: Any, **kwargs: Any) -> None: + pass + + +def test_handle_event_chat_model_start_fallback_to_llm_start() -> None: + """on_chat_model_start raises NotImplementedError → falls back to on_llm_start.""" + handler = _FallbackChatHandler() + handler.on_llm_start = MagicMock() # type: ignore[method-assign] + + serialized = {"name": "test"} + messages = [[HumanMessage(content="hello")]] + + handle_event( + [handler], + "on_chat_model_start", + "ignore_chat_model", + serialized, + messages, + ) + + handler.on_llm_start.assert_called_once() + + +def test_handle_event_other_event_not_implemented_logs_warning() -> None: + """Non-chat_model_start events that raise NotImplementedError log a warning.""" + + class _Handler(BaseCallbackHandler): + def on_llm_start(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + + handler = _Handler() + + # Should not raise — logs a warning instead + handle_event( + [handler], + "on_llm_start", + "ignore_llm", + {"name": "test"}, + ["prompt"], + ) + + +@pytest.mark.asyncio +async def test_ahandle_event_chat_model_start_fallback_to_llm_start() -> None: + """Async: on_chat_model_start NotImplementedError falls back to on_llm_start.""" + handler = _FallbackChatHandlerAsync() + handler.on_llm_start = MagicMock() # type: ignore[method-assign] + + serialized = {"name": "test"} + messages = [[HumanMessage(content="hello")]] + + await _ahandle_event_for_handler( + handler, + "on_chat_model_start", + "ignore_chat_model", + serialized, + messages, + ) + + handler.on_llm_start.assert_called_once() + + +@pytest.mark.asyncio +async def test_ahandle_event_other_event_not_implemented_logs_warning() -> None: + """Async: non-chat_model_start events log warning on NotImplementedError.""" + + class _Handler(BaseCallbackHandler): + run_inline = True + + def on_llm_start(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + + handler = _Handler() + + await _ahandle_event_for_handler( + handler, + "on_llm_start", + "ignore_llm", + {"name": "test"}, + ["prompt"], + )