fix(core): async tracer on_chat_model_start fallback in sync context (#35233)

Fixes #30870

When an `AsyncBaseTracer` with `_schema_format="original"` (the default)
is used with sync `llm.invoke()`, the fallback from `on_chat_model_start`
to `on_llm_start` doesn't fire. Calling the async handler returns a
coroutine rather than raising `NotImplementedError` synchronously, so
the error bypasses the existing fallback logic and only surfaces later
in `_run_coros`, which merely logs it generically.

This fallback already works for sync handlers in sync context and async
handlers in async context. This PR closes the gap for async handlers in
sync context.
This commit is contained in:
Mason Daugherty
2026-06-10 22:15:29 -04:00
committed by GitHub
parent 8fc58c6013
commit 7cc9d0c84d
2 changed files with 173 additions and 7 deletions

View File

@@ -252,6 +252,35 @@ def shielded(func: Func) -> Func:
return cast("Func", wrapped)
async def _achat_model_start_fallback(
    coro: Coroutine[Any, Any, Any],
    handler: BaseCallbackHandler,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Await an async `on_chat_model_start` call, falling back on failure.

    Sync `handle_event` cannot await an async handler inline; it collects the
    returned coroutine into `coros` and `_run_coros` executes it later. A
    `NotImplementedError` raised by the handler therefore only appears once
    the coroutine is awaited. This wrapper catches it at that point and
    dispatches `on_llm_start` to the same handler instead, so the error is
    not merely caught generically by `_run_coros` and the trace lost.

    Args:
        coro: Coroutine returned by the handler's `on_chat_model_start`.
        handler: Handler that produced `coro`; it receives the fallback event.
        *args: Positional arguments of the original event. The first is
            forwarded unchanged, the second is iterated and each item is
            converted with `get_buffer_string`, and the rest are forwarded
            unchanged.
        **kwargs: Keyword arguments of the original event, forwarded as-is.
    """
    try:
        await coro
    except NotImplementedError:
        # `on_llm_start` takes string prompts, so flatten each entry of the
        # second positional argument into a single buffer string.
        prompts = [get_buffer_string(batch) for batch in args[1]]
        serialized, extra_args = args[0], args[2:]
        await _ahandle_event_for_handler(
            handler,
            "on_llm_start",
            "ignore_llm",
            serialized,
            prompts,
            *extra_args,
            **kwargs,
        )
def handle_event(
handlers: list[BaseCallbackHandler],
event_name: str,
@@ -281,6 +310,10 @@ def handle_event(
):
event = getattr(handler, event_name)(*args, **kwargs)
if asyncio.iscoroutine(event):
if event_name == "on_chat_model_start":
event = _achat_model_start_fallback(
event, handler, *args, **kwargs
)
coros.append(event)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
@@ -334,6 +367,11 @@ def handle_event(
def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
# Note: exceptions raised by these coroutines are always logged and swallowed
# here, regardless of the handler's `raise_error` setting. Async-handler errors
# driven through sync `handle_event` therefore never propagate, unlike errors
# from sync handlers (which honor `raise_error`). This is a pre-existing
# asymmetry between the sync and async callback paths.
if hasattr(asyncio, "Runner"):
# Python 3.11+
# Run the coroutines in a new event loop, taking care to

View File

@@ -5,10 +5,14 @@ Handlers must declare `serialized` and `messages` as explicit positional args
(not *args) — see on_chat_model_start docstring for details.
See: https://github.com/langchain-ai/langchain/issues/31576
See: https://github.com/langchain-ai/langchain/issues/30870
"""
from typing import Any
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock
from uuid import UUID, uuid4
import pytest
@@ -17,7 +21,13 @@ from langchain_core.callbacks.manager import (
_ahandle_event_for_handler,
handle_event,
)
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.tracers.base import AsyncBaseTracer
if TYPE_CHECKING:
from langchain_core.tracers.schemas import Run
SERIALIZED = {"id": ["chat_model"]}
class _FallbackChatHandler(BaseCallbackHandler):
@@ -55,19 +65,74 @@ class _FallbackChatHandlerAsync(BaseCallbackHandler):
pass
class _NoOpAsyncTracer(AsyncBaseTracer):
    """Async tracer that leaves `on_chat_model_start` un-overridden.

    Every `on_llm_start` invocation is recorded so tests can inspect what
    arguments reached the handler.
    """

    def __init__(self) -> None:
        super().__init__()
        # One entry per `on_llm_start` call, in call order.
        self.llm_start_calls: list[dict[str, Any]] = []
        # Runs handed to `_persist_run`.
        self.runs: list[Run] = []

    async def _persist_run(self, run: Run) -> None:
        """Keep the run in memory rather than persisting it anywhere."""
        self.runs.append(run)

    async def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        *,
        run_id: UUID,
        **_kwargs: Any,
    ) -> None:
        """Record the call; any additional keyword arguments are ignored."""
        record: dict[str, Any] = {
            "serialized": serialized,
            "prompts": prompts,
            "run_id": run_id,
        }
        self.llm_start_calls.append(record)
class _WorkingAsyncTracer(AsyncBaseTracer):
    """Async tracer that provides its own `on_chat_model_start`.

    Every `on_chat_model_start` invocation is recorded so tests can inspect
    what arguments reached the handler.
    """

    def __init__(self) -> None:
        super().__init__()
        # One entry per `on_chat_model_start` call, in call order.
        self.chat_model_start_calls: list[dict[str, Any]] = []
        # Runs handed to `_persist_run`.
        self.runs: list[Run] = []

    async def _persist_run(self, run: Run) -> None:
        """Keep the run in memory rather than persisting it anywhere."""
        self.runs.append(run)

    async def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[Any]],
        *,
        run_id: UUID,
        **_kwargs: Any,
    ) -> None:
        """Record the call; any additional keyword arguments are ignored."""
        record: dict[str, Any] = {
            "serialized": serialized,
            "messages": messages,
            "run_id": run_id,
        }
        self.chat_model_start_calls.append(record)
def test_handle_event_chat_model_start_fallback_to_llm_start() -> None:
"""on_chat_model_start raises NotImplementedError → falls back to on_llm_start."""
handler = _FallbackChatHandler()
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
serialized = {"name": "test"}
messages = [[HumanMessage(content="hello")]]
handle_event(
[handler],
"on_chat_model_start",
"ignore_chat_model",
serialized,
SERIALIZED,
messages,
)
@@ -83,7 +148,6 @@ def test_handle_event_other_event_not_implemented_logs_warning() -> None:
handler = _Handler()
# Should not raise — logs a warning instead
handle_event(
[handler],
"on_llm_start",
@@ -99,14 +163,13 @@ async def test_ahandle_event_chat_model_start_fallback_to_llm_start() -> None:
handler = _FallbackChatHandlerAsync()
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
serialized = {"name": "test"}
messages = [[HumanMessage(content="hello")]]
await _ahandle_event_for_handler(
handler,
"on_chat_model_start",
"ignore_chat_model",
serialized,
SERIALIZED,
messages,
)
@@ -132,3 +195,68 @@ async def test_ahandle_event_other_event_not_implemented_logs_warning() -> None:
{"name": "test"},
["prompt"],
)
def test_async_tracer_falls_back_to_on_llm_start_in_sync_context() -> None:
    """Async tracer without on_chat_model_start falls back in handle_event.

    Besides checking that `on_llm_start` is reached exactly once, this
    verifies that the fallback preserves the event payload: the `run_id`
    keyword argument arrives unchanged and the single message batch is
    rendered into one prompt string containing every message's content.
    """
    tracer = _NoOpAsyncTracer()
    run_id = uuid4()
    messages = [[SystemMessage(content="sys"), HumanMessage(content="hi")]]

    handle_event(
        [tracer],
        "on_chat_model_start",
        "ignore_chat_model",
        SERIALIZED,
        messages,
        run_id=run_id,
    )

    assert len(tracer.llm_start_calls) == 1
    call = tracer.llm_start_calls[0]
    assert call["serialized"] == SERIALIZED
    # Keyword arguments must be forwarded to the fallback untouched.
    assert call["run_id"] == run_id

    prompts = call["prompts"]
    assert isinstance(prompts, list)
    # One message batch in, one prompt string out.
    assert len(prompts) == 1
    prompt = prompts[0]
    assert isinstance(prompt, str)
    # Content of every message in the batch must reach the prompt.
    assert "sys" in prompt
    assert "hi" in prompt
def test_async_tracer_no_fallback_when_implemented() -> None:
    """Async tracer with on_chat_model_start does not fall back.

    The handler's own `on_chat_model_start` must receive the original
    serialized payload and the untouched message batches.
    """
    messages = [[HumanMessage(content="hello")]]
    tracer = _WorkingAsyncTracer()

    handle_event(
        [tracer],
        "on_chat_model_start",
        "ignore_chat_model",
        SERIALIZED,
        messages,
        run_id=uuid4(),
    )

    recorded = tracer.chat_model_start_calls
    assert len(recorded) == 1
    first = recorded[0]
    assert first["serialized"] == SERIALIZED
    assert first["messages"] == messages
def test_async_tracer_fallback_no_error_logged(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Async tracer fallback path should not produce warning logs.

    The fallback is also asserted to have fired, so the "no warnings"
    check cannot pass vacuously when the handler is never invoked.
    """
    tracer = _NoOpAsyncTracer()
    messages = [[HumanMessage(content="test")]]

    with caplog.at_level("WARNING", logger="langchain_core.callbacks.manager"):
        handle_event(
            [tracer],
            "on_chat_model_start",
            "ignore_chat_model",
            SERIALIZED,
            messages,
            run_id=uuid4(),
        )

    # Guard against a silent no-op: the fallback must have reached the tracer.
    assert len(tracer.llm_start_calls) == 1
    assert not caplog.records, (
        f"Expected no warnings but got: {[r.message for r in caplog.records]}"
    )