mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
fix(core): async tracer on_chat_model_start fallback in sync context (#35233)
Fixes #30870 When an `AsyncBaseTracer` with `_schema_format="original"` (the default) is used with sync `llm.invoke()`, the `on_chat_model_start` to `on_llm_start` fallback doesn't fire. The async handler returns a coroutine instead of raising `NotImplementedError` synchronously, so it bypasses the existing fallback logic and lands in `_run_coros`, which only logs the error generically. This fallback already works for sync handlers in sync context and async handlers in async context. This PR closes the gap for async handlers in sync context.
This commit is contained in:
@@ -252,6 +252,35 @@ def shielded(func: Func) -> Func:
|
||||
return cast("Func", wrapped)
|
||||
|
||||
|
||||
async def _achat_model_start_fallback(
|
||||
coro: Coroutine[Any, Any, Any],
|
||||
handler: BaseCallbackHandler,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Wrap an async `on_chat_model_start` coroutine with fallback.
|
||||
|
||||
Catches `NotImplementedError` and triggers the `on_llm_start` fallback.
|
||||
This covers async handlers invoked from a **sync** `handle_event` call,
|
||||
where the coroutine is collected into `coros` and executed later by
|
||||
`_run_coros`. Without this wrapper the `NotImplementedError` would be
|
||||
caught generically by `_run_coros` and the trace would be lost.
|
||||
"""
|
||||
try:
|
||||
await coro
|
||||
except NotImplementedError:
|
||||
message_strings = [get_buffer_string(m) for m in args[1]]
|
||||
await _ahandle_event_for_handler(
|
||||
handler,
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
args[0],
|
||||
message_strings,
|
||||
*args[2:],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def handle_event(
|
||||
handlers: list[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
@@ -281,6 +310,10 @@ def handle_event(
|
||||
):
|
||||
event = getattr(handler, event_name)(*args, **kwargs)
|
||||
if asyncio.iscoroutine(event):
|
||||
if event_name == "on_chat_model_start":
|
||||
event = _achat_model_start_fallback(
|
||||
event, handler, *args, **kwargs
|
||||
)
|
||||
coros.append(event)
|
||||
except NotImplementedError as e:
|
||||
if event_name == "on_chat_model_start":
|
||||
@@ -334,6 +367,11 @@ def handle_event(
|
||||
|
||||
|
||||
def _run_coros(coros: list[Coroutine[Any, Any, Any]]) -> None:
|
||||
# Note: exceptions raised by these coroutines are always logged and swallowed
|
||||
# here, regardless of the handler's `raise_error` setting. Async-handler errors
|
||||
# driven through sync `handle_event` therefore never propagate, unlike errors
|
||||
# from sync handlers (which honor `raise_error`). This is a pre-existing
|
||||
# asymmetry between the sync and async callback paths.
|
||||
if hasattr(asyncio, "Runner"):
|
||||
# Python 3.11+
|
||||
# Run the coroutines in a new event loop, taking care to
|
||||
|
||||
@@ -5,10 +5,14 @@ Handlers must declare `serialized` and `messages` as explicit positional args
|
||||
(not *args) — see on_chat_model_start docstring for details.
|
||||
|
||||
See: https://github.com/langchain-ai/langchain/issues/31576
|
||||
See: https://github.com/langchain-ai/langchain/issues/30870
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -17,7 +21,13 @@ from langchain_core.callbacks.manager import (
|
||||
_ahandle_event_for_handler,
|
||||
handle_event,
|
||||
)
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tracers.base import AsyncBaseTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
SERIALIZED = {"id": ["chat_model"]}
|
||||
|
||||
|
||||
class _FallbackChatHandler(BaseCallbackHandler):
|
||||
@@ -55,19 +65,74 @@ class _FallbackChatHandlerAsync(BaseCallbackHandler):
|
||||
pass
|
||||
|
||||
|
||||
class _NoOpAsyncTracer(AsyncBaseTracer):
|
||||
"""Async tracer that does not override on_chat_model_start."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.runs: list[Run] = []
|
||||
self.llm_start_calls: list[dict[str, Any]] = []
|
||||
|
||||
async def _persist_run(self, run: Run) -> None:
|
||||
self.runs.append(run)
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
prompts: list[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**_kwargs: Any,
|
||||
) -> None:
|
||||
self.llm_start_calls.append(
|
||||
{
|
||||
"serialized": serialized,
|
||||
"prompts": prompts,
|
||||
"run_id": run_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class _WorkingAsyncTracer(AsyncBaseTracer):
|
||||
"""Async tracer that implements on_chat_model_start."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.runs: list[Run] = []
|
||||
self.chat_model_start_calls: list[dict[str, Any]] = []
|
||||
|
||||
async def _persist_run(self, run: Run) -> None:
|
||||
self.runs.append(run)
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
messages: list[list[Any]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
**_kwargs: Any,
|
||||
) -> None:
|
||||
self.chat_model_start_calls.append(
|
||||
{
|
||||
"serialized": serialized,
|
||||
"messages": messages,
|
||||
"run_id": run_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_handle_event_chat_model_start_fallback_to_llm_start() -> None:
|
||||
"""on_chat_model_start raises NotImplementedError → falls back to on_llm_start."""
|
||||
handler = _FallbackChatHandler()
|
||||
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
serialized = {"name": "test"}
|
||||
messages = [[HumanMessage(content="hello")]]
|
||||
|
||||
handle_event(
|
||||
[handler],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
SERIALIZED,
|
||||
messages,
|
||||
)
|
||||
|
||||
@@ -83,7 +148,6 @@ def test_handle_event_other_event_not_implemented_logs_warning() -> None:
|
||||
|
||||
handler = _Handler()
|
||||
|
||||
# Should not raise — logs a warning instead
|
||||
handle_event(
|
||||
[handler],
|
||||
"on_llm_start",
|
||||
@@ -99,14 +163,13 @@ async def test_ahandle_event_chat_model_start_fallback_to_llm_start() -> None:
|
||||
handler = _FallbackChatHandlerAsync()
|
||||
handler.on_llm_start = MagicMock() # type: ignore[method-assign]
|
||||
|
||||
serialized = {"name": "test"}
|
||||
messages = [[HumanMessage(content="hello")]]
|
||||
|
||||
await _ahandle_event_for_handler(
|
||||
handler,
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
serialized,
|
||||
SERIALIZED,
|
||||
messages,
|
||||
)
|
||||
|
||||
@@ -132,3 +195,68 @@ async def test_ahandle_event_other_event_not_implemented_logs_warning() -> None:
|
||||
{"name": "test"},
|
||||
["prompt"],
|
||||
)
|
||||
|
||||
|
||||
def test_async_tracer_falls_back_to_on_llm_start_in_sync_context() -> None:
|
||||
"""Async tracer without on_chat_model_start falls back in handle_event."""
|
||||
tracer = _NoOpAsyncTracer()
|
||||
run_id = uuid4()
|
||||
messages = [[SystemMessage(content="sys"), HumanMessage(content="hi")]]
|
||||
|
||||
handle_event(
|
||||
[tracer],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
SERIALIZED,
|
||||
messages,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
assert len(tracer.llm_start_calls) == 1
|
||||
call = tracer.llm_start_calls[0]
|
||||
assert call["serialized"] == SERIALIZED
|
||||
assert isinstance(call["prompts"], list)
|
||||
assert len(call["prompts"]) == 1
|
||||
assert isinstance(call["prompts"][0], str)
|
||||
|
||||
|
||||
def test_async_tracer_no_fallback_when_implemented() -> None:
|
||||
"""Async tracer with on_chat_model_start does not fall back."""
|
||||
tracer = _WorkingAsyncTracer()
|
||||
messages = [[HumanMessage(content="hello")]]
|
||||
|
||||
handle_event(
|
||||
[tracer],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
SERIALIZED,
|
||||
messages,
|
||||
run_id=uuid4(),
|
||||
)
|
||||
|
||||
assert len(tracer.chat_model_start_calls) == 1
|
||||
call = tracer.chat_model_start_calls[0]
|
||||
assert call["serialized"] == SERIALIZED
|
||||
assert call["messages"] == messages
|
||||
|
||||
|
||||
def test_async_tracer_fallback_no_error_logged(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Async tracer fallback path should not produce warning logs."""
|
||||
tracer = _NoOpAsyncTracer()
|
||||
messages = [[HumanMessage(content="test")]]
|
||||
|
||||
with caplog.at_level("WARNING", logger="langchain_core.callbacks.manager"):
|
||||
handle_event(
|
||||
[tracer],
|
||||
"on_chat_model_start",
|
||||
"ignore_chat_model",
|
||||
SERIALIZED,
|
||||
messages,
|
||||
run_id=uuid4(),
|
||||
)
|
||||
|
||||
assert not caplog.records, (
|
||||
f"Expected no warnings but got: {[r.message for r in caplog.records]}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user