fix(core): improve typing/docs for on_chat_model_start to clarify required positional args (#35324)

This commit is contained in:
Balaji Seshadri
2026-02-22 14:46:32 -05:00
committed by GitHub
parent 875c3c573d
commit d6e46bb4b0
2 changed files with 178 additions and 4 deletions

View File

@@ -285,9 +285,29 @@ class CallbackManagerMixin:
This method is called for chat models. If you're implementing a handler for
a non-chat model, you should use `on_llm_start` instead.
!!! note
When overriding this method, the signature **must** include the two
required positional arguments ``serialized`` and ``messages``. Avoid
using ``*args`` in your override — doing so causes an ``IndexError``
in the fallback path when the callback system converts ``messages``
to prompt strings for ``on_llm_start``. Always declare the
signature explicitly:
.. code-block:: python
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any,
) -> None:
raise NotImplementedError # triggers fallback to on_llm_start
Args:
serialized: The serialized chat model.
messages: The messages.
messages: The messages. Must be a list of message lists — this is a
required positional argument and must be present in any override.
run_id: The ID of the current run.
parent_run_id: The ID of the parent run.
tags: The tags.
@@ -295,7 +315,7 @@ class CallbackManagerMixin:
**kwargs: Additional keyword arguments.
"""
# NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this is exception is thrown
# Callback handler will fall back to on_llm_start if this exception is thrown
msg = f"{self.__class__.__name__} does not implement `on_chat_model_start`"
raise NotImplementedError(msg)
@@ -534,9 +554,29 @@ class AsyncCallbackHandler(BaseCallbackHandler):
This method is called for chat models. If you're implementing a handler for
a non-chat model, you should use `on_llm_start` instead.
!!! note
When overriding this method, the signature **must** include the two
required positional arguments ``serialized`` and ``messages``. Avoid
using ``*args`` in your override — doing so causes an ``IndexError``
in the fallback path when the callback system converts ``messages``
to prompt strings for ``on_llm_start``. Always declare the
signature explicitly:
.. code-block:: python
async def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any,
) -> None:
raise NotImplementedError # triggers fallback to on_llm_start
Args:
serialized: The serialized chat model.
messages: The messages.
messages: The messages. Must be a list of message lists — this is a
required positional argument and must be present in any override.
run_id: The ID of the current run.
parent_run_id: The ID of the parent run.
tags: The tags.
@@ -544,7 +584,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
**kwargs: Additional keyword arguments.
"""
# NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this is exception is thrown
# Callback handler will fall back to on_llm_start if this exception is thrown
msg = f"{self.__class__.__name__} does not implement `on_chat_model_start`"
raise NotImplementedError(msg)

View File

@@ -0,0 +1,134 @@
"""Tests for handle_event and _ahandle_event_for_handler fallback behavior.
Covers the NotImplementedError fallback from on_chat_model_start to on_llm_start.
Handlers must declare `serialized` and `messages` as explicit positional args
(not *args) — see on_chat_model_start docstring for details.
See: https://github.com/langchain-ai/langchain/issues/31576
"""
from typing import Any
from unittest.mock import MagicMock
import pytest
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.callbacks.manager import (
_ahandle_event_for_handler,
handle_event,
)
from langchain_core.messages import BaseMessage, HumanMessage
class _FallbackChatHandler(BaseCallbackHandler):
"""Handler that correctly declares the required args but raises NotImplementedError.
This triggers the fallback to on_llm_start, as documented.
"""
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any,
) -> None:
raise NotImplementedError
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
pass
class _FallbackChatHandlerAsync(BaseCallbackHandler):
"""Async-compatible handler; raises NotImplementedError for on_chat_model_start."""
run_inline = True
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any,
) -> None:
raise NotImplementedError
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
pass
def test_handle_event_chat_model_start_fallback_to_llm_start() -> None:
"""on_chat_model_start raises NotImplementedError → falls back to on_llm_start."""
handler = _FallbackChatHandler()
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
serialized = {"name": "test"}
messages = [[HumanMessage(content="hello")]]
handle_event(
[handler],
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
)
handler.on_llm_start.assert_called_once()
def test_handle_event_other_event_not_implemented_logs_warning() -> None:
"""Non-chat_model_start events that raise NotImplementedError log a warning."""
class _Handler(BaseCallbackHandler):
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
handler = _Handler()
# Should not raise — logs a warning instead
handle_event(
[handler],
"on_llm_start",
"ignore_llm",
{"name": "test"},
["prompt"],
)
@pytest.mark.asyncio
async def test_ahandle_event_chat_model_start_fallback_to_llm_start() -> None:
"""Async: on_chat_model_start NotImplementedError falls back to on_llm_start."""
handler = _FallbackChatHandlerAsync()
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
serialized = {"name": "test"}
messages = [[HumanMessage(content="hello")]]
await _ahandle_event_for_handler(
handler,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
)
handler.on_llm_start.assert_called_once()
@pytest.mark.asyncio
async def test_ahandle_event_other_event_not_implemented_logs_warning() -> None:
"""Async: non-chat_model_start events log warning on NotImplementedError."""
class _Handler(BaseCallbackHandler):
run_inline = True
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
handler = _Handler()
await _ahandle_event_for_handler(
handler,
"on_llm_start",
"ignore_llm",
{"name": "test"},
["prompt"],
)