diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 8d46f74eea2..f5f7d504492 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -252,6 +252,35 @@ def shielded(func: Func) -> Func: return cast("Func", wrapped) +async def _achat_model_start_fallback( + coro: Coroutine[Any, Any, Any], + handler: BaseCallbackHandler, + *args: Any, + **kwargs: Any, +) -> None: + """Wrap an async `on_chat_model_start` coroutine with fallback. + + Catches `NotImplementedError` and triggers the `on_llm_start` fallback. + This covers async handlers invoked from a **sync** `handle_event` call, + where the coroutine is collected into `coros` and executed later by + `_run_coros`. Without this wrapper the `NotImplementedError` would be + caught generically by `_run_coros` and the trace would be lost. + """ + try: + await coro + except NotImplementedError: + message_strings = [get_buffer_string(m) for m in args[1]] + await _ahandle_event_for_handler( + handler, + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + + def handle_event( handlers: list[BaseCallbackHandler], event_name: str, @@ -281,6 +310,10 @@ def handle_event( ): event = getattr(handler, event_name)(*args, **kwargs) if asyncio.iscoroutine(event): + if event_name == "on_chat_model_start": + event = _achat_model_start_fallback( + event, handler, *args, **kwargs + ) coros.append(event) except NotImplementedError as e: if event_name == "on_chat_model_start": @@ -334,6 +367,11 @@ def handle_event( def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None: + # Note: exceptions raised by these coroutines are always logged and swallowed + # here, regardless of the handler's `raise_error` setting. Async-handler errors + # driven through sync `handle_event` therefore never propagate, unlike errors + # from sync handlers (which honor `raise_error`). This is a pre-existing + # asymmetry between the sync and async callback paths. if hasattr(asyncio, "Runner"): # Python 3.11+ # Run the coroutines in a new event loop, taking care to diff --git a/libs/core/tests/unit_tests/callbacks/test_handle_event.py b/libs/core/tests/unit_tests/callbacks/test_handle_event.py index b824717a9dc..bcda51af3eb 100644 --- a/libs/core/tests/unit_tests/callbacks/test_handle_event.py +++ b/libs/core/tests/unit_tests/callbacks/test_handle_event.py @@ -5,10 +5,14 @@ Handlers must declare `serialized` and `messages` as explicit positional args (not *args) — see on_chat_model_start docstring for details. See: https://github.com/langchain-ai/langchain/issues/31576 +See: https://github.com/langchain-ai/langchain/issues/30870 """ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock +from uuid import UUID, uuid4 import pytest @@ -17,7 +21,13 @@ from langchain_core.callbacks.manager import ( _ahandle_event_for_handler, handle_event, ) -from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_core.tracers.base import AsyncBaseTracer + +if TYPE_CHECKING: + from langchain_core.tracers.schemas import Run + +SERIALIZED = {"id": ["chat_model"]} class _FallbackChatHandler(BaseCallbackHandler): @@ -55,19 +65,74 @@ class _FallbackChatHandlerAsync(BaseCallbackHandler): pass +class _NoOpAsyncTracer(AsyncBaseTracer): + """Async tracer that does not override on_chat_model_start.""" + + def __init__(self) -> None: + super().__init__() + self.runs: list[Run] = [] + self.llm_start_calls: list[dict[str, Any]] = [] + + async def _persist_run(self, run: Run) -> None: + self.runs.append(run) + + async def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + **_kwargs: Any, + ) -> None: + self.llm_start_calls.append( + { + "serialized": serialized, + "prompts": prompts, + "run_id": run_id, + } + ) + + +class _WorkingAsyncTracer(AsyncBaseTracer): + """Async tracer that implements on_chat_model_start.""" + + def __init__(self) -> None: + super().__init__() + self.runs: list[Run] = [] + self.chat_model_start_calls: list[dict[str, Any]] = [] + + async def _persist_run(self, run: Run) -> None: + self.runs.append(run) + + async def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[Any]], + *, + run_id: UUID, + **_kwargs: Any, + ) -> None: + self.chat_model_start_calls.append( + { + "serialized": serialized, + "messages": messages, + "run_id": run_id, + } + ) + + def test_handle_event_chat_model_start_fallback_to_llm_start() -> None: """on_chat_model_start raises NotImplementedError → falls back to on_llm_start.""" handler = _FallbackChatHandler() handler.on_llm_start = MagicMock() # type: ignore[method-assign] - serialized = {"name": "test"} messages = [[HumanMessage(content="hello")]] handle_event( [handler], "on_chat_model_start", "ignore_chat_model", - serialized, + SERIALIZED, messages, ) @@ -83,7 +148,6 @@ def test_handle_event_other_event_not_implemented_logs_warning() -> None: handler = _Handler() - # Should not raise — logs a warning instead handle_event( [handler], "on_llm_start", @@ -99,14 +163,13 @@ async def test_ahandle_event_chat_model_start_fallback_to_llm_start() -> None: handler = _FallbackChatHandlerAsync() handler.on_llm_start = MagicMock() # type: ignore[method-assign] - serialized = {"name": "test"} messages = [[HumanMessage(content="hello")]] await _ahandle_event_for_handler( handler, "on_chat_model_start", "ignore_chat_model", - serialized, + SERIALIZED, messages, ) @@ -132,3 +195,68 @@ async def test_ahandle_event_other_event_not_implemented_logs_warning() -> None: {"name": "test"}, ["prompt"], ) + + +def test_async_tracer_falls_back_to_on_llm_start_in_sync_context() -> None: + """Async tracer without on_chat_model_start falls back in handle_event.""" + tracer = _NoOpAsyncTracer() + run_id = uuid4() + messages = [[SystemMessage(content="sys"), HumanMessage(content="hi")]] + + handle_event( + [tracer], + "on_chat_model_start", + "ignore_chat_model", + SERIALIZED, + messages, + run_id=run_id, + ) + + assert len(tracer.llm_start_calls) == 1 + call = tracer.llm_start_calls[0] + assert call["serialized"] == SERIALIZED + assert isinstance(call["prompts"], list) + assert len(call["prompts"]) == 1 + assert isinstance(call["prompts"][0], str) + + +def test_async_tracer_no_fallback_when_implemented() -> None: + """Async tracer with on_chat_model_start does not fall back.""" + tracer = _WorkingAsyncTracer() + messages = [[HumanMessage(content="hello")]] + + handle_event( + [tracer], + "on_chat_model_start", + "ignore_chat_model", + SERIALIZED, + messages, + run_id=uuid4(), + ) + + assert len(tracer.chat_model_start_calls) == 1 + call = tracer.chat_model_start_calls[0] + assert call["serialized"] == SERIALIZED + assert call["messages"] == messages + + +def test_async_tracer_fallback_no_error_logged( + caplog: pytest.LogCaptureFixture, +) -> None: + """Async tracer fallback path should not produce warning logs.""" + tracer = _NoOpAsyncTracer() + messages = [[HumanMessage(content="test")]] + + with caplog.at_level("WARNING", logger="langchain_core.callbacks.manager"): + handle_event( + [tracer], + "on_chat_model_start", + "ignore_chat_model", + SERIALIZED, + messages, + run_id=uuid4(), + ) + + assert not caplog.records, ( + f"Expected no warnings but got: {[r.message for r in caplog.records]}" + )