mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
fix(core): improve typing/docs for on_chat_model_start to clarify required positional args (#35324)
This commit is contained in:
@@ -285,9 +285,29 @@ class CallbackManagerMixin:
|
||||
This method is called for chat models. If you're implementing a handler for
|
||||
a non-chat model, you should use `on_llm_start` instead.
|
||||
|
||||
!!! note
|
||||
|
||||
When overriding this method, the signature **must** include the two
|
||||
required positional arguments ``serialized`` and ``messages``. Avoid
|
||||
using ``*args`` in your override — doing so causes an ``IndexError``
|
||||
in the fallback path when the callback system converts ``messages``
|
||||
to prompt strings for ``on_llm_start``. Always declare the
|
||||
signature explicitly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError # triggers fallback to on_llm_start
|
||||
|
||||
Args:
|
||||
serialized: The serialized chat model.
|
||||
messages: The messages.
|
||||
messages: The messages. Must be a list of message lists — this is a
|
||||
required positional argument and must be present in any override.
|
||||
run_id: The ID of the current run.
|
||||
parent_run_id: The ID of the parent run.
|
||||
tags: The tags.
|
||||
@@ -295,7 +315,7 @@ class CallbackManagerMixin:
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
# NotImplementedError is thrown intentionally
|
||||
# Callback handler will fall back to on_llm_start if this is exception is thrown
|
||||
# Callback handler will fall back to on_llm_start if this exception is thrown
|
||||
msg = f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@@ -534,9 +554,29 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
This method is called for chat models. If you're implementing a handler for
|
||||
a non-chat model, you should use `on_llm_start` instead.
|
||||
|
||||
!!! note
|
||||
|
||||
When overriding this method, the signature **must** include the two
|
||||
required positional arguments ``serialized`` and ``messages``. Avoid
|
||||
using ``*args`` in your override — doing so causes an ``IndexError``
|
||||
in the fallback path when the callback system converts ``messages``
|
||||
to prompt strings for ``on_llm_start``. Always declare the
|
||||
signature explicitly:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError # triggers fallback to on_llm_start
|
||||
|
||||
Args:
|
||||
serialized: The serialized chat model.
|
||||
messages: The messages.
|
||||
messages: The messages. Must be a list of message lists — this is a
|
||||
required positional argument and must be present in any override.
|
||||
run_id: The ID of the current run.
|
||||
parent_run_id: The ID of the parent run.
|
||||
tags: The tags.
|
||||
@@ -544,7 +584,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
# NotImplementedError is thrown intentionally
|
||||
# Callback handler will fall back to on_llm_start if this is exception is thrown
|
||||
# Callback handler will fall back to on_llm_start if this exception is thrown
|
||||
msg = f"{self.__class__.__name__} does not implement `on_chat_model_start`"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
134
libs/core/tests/unit_tests/callbacks/test_handle_event.py
Normal file
134
libs/core/tests/unit_tests/callbacks/test_handle_event.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Tests for handle_event and _ahandle_event_for_handler fallback behavior.
|
||||
|
||||
Covers the NotImplementedError fallback from on_chat_model_start to on_llm_start.
|
||||
Handlers must declare `serialized` and `messages` as explicit positional args
|
||||
(not *args) — see on_chat_model_start docstring for details.
|
||||
|
||||
See: https://github.com/langchain-ai/langchain/issues/31576
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.callbacks.manager import (
|
||||
_ahandle_event_for_handler,
|
||||
handle_event,
|
||||
)
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
|
||||
|
||||
class _FallbackChatHandler(BaseCallbackHandler):
|
||||
"""Handler that correctly declares the required args but raises NotImplementedError.
|
||||
|
||||
This triggers the fallback to on_llm_start, as documented.
|
||||
"""
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FallbackChatHandlerAsync(BaseCallbackHandler):
|
||||
"""Async-compatible handler; raises NotImplementedError for on_chat_model_start."""
|
||||
|
||||
run_inline = True
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[BaseMessage]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_handle_event_chat_model_start_fallback_to_llm_start() -> None:
|
||||
"""on_chat_model_start raises NotImplementedError → falls back to on_llm_start."""
|
||||
handler = _FallbackChatHandler()
|
||||
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
serialized = {"name": "test"}
|
||||
messages = [[HumanMessage(content="hello")]]
|
||||
|
||||
handle_event(
|
||||
[handler],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
)
|
||||
|
||||
handler.on_llm_start.assert_called_once()
|
||||
|
||||
|
||||
def test_handle_event_other_event_not_implemented_logs_warning() -> None:
|
||||
"""Non-chat_model_start events that raise NotImplementedError log a warning."""
|
||||
|
||||
class _Handler(BaseCallbackHandler):
|
||||
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
handler = _Handler()
|
||||
|
||||
# Should not raise — logs a warning instead
|
||||
handle_event(
|
||||
[handler],
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
{"name": "test"},
|
||||
["prompt"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ahandle_event_chat_model_start_fallback_to_llm_start() -> None:
|
||||
"""Async: on_chat_model_start NotImplementedError falls back to on_llm_start."""
|
||||
handler = _FallbackChatHandlerAsync()
|
||||
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
serialized = {"name": "test"}
|
||||
messages = [[HumanMessage(content="hello")]]
|
||||
|
||||
await _ahandle_event_for_handler(
|
||||
handler,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
messages,
|
||||
)
|
||||
|
||||
handler.on_llm_start.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ahandle_event_other_event_not_implemented_logs_warning() -> None:
|
||||
"""Async: non-chat_model_start events log warning on NotImplementedError."""
|
||||
|
||||
class _Handler(BaseCallbackHandler):
|
||||
run_inline = True
|
||||
|
||||
def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
handler = _Handler()
|
||||
|
||||
await _ahandle_event_for_handler(
|
||||
handler,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
{"name": "test"},
|
||||
["prompt"],
|
||||
)
|
||||
Reference in New Issue
Block a user