core[patch]: Update AsyncCallbackManager to honor run_inline attribute and prevent context loss (#26885)

## Description

This PR fixes the context loss issue in `AsyncCallbackManager`,
specifically in the `on_llm_start` and `on_chat_model_start` methods. It
properly honors the `run_inline` attribute of callback handlers,
preventing race conditions and ordering issues.

Key changes:
1. Separate handlers into inline and non-inline groups.
2. Execute inline handlers sequentially for each prompt.
3. Execute non-inline handlers concurrently across all prompts.
4. Preserve context for stateful handlers.
5. Maintain performance benefits for non-inline handlers.

**These changes are implemented in `AsyncCallbackManager` rather than
`ahandle_event` because the issue occurs at the prompt and message_list
levels, not within individual events.**

## Testing

- The test case implemented in #26857 now passes, verifying the execution
order of inline handlers.

## Related Issues

- Fixes issue discussed in #23909 

## Dependencies

No new dependencies are required.

---

@eyurtsev: This PR implements the discussed changes to respect
`run_inline` in `AsyncCallbackManager`. Please review and advise on any
needed changes.

Twitter handle: @parambharat

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Bharat Ramanathan 2024-10-08 00:29:29 +05:30 committed by GitHub
parent c61b9daef5
commit 931ce8d026
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 207 additions and 21 deletions

View File

@ -1729,7 +1729,12 @@ class AsyncCallbackManager(BaseCallbackManager):
to each prompt.
"""
tasks = []
inline_tasks = []
non_inline_tasks = []
inline_handlers = [handler for handler in self.handlers if handler.run_inline]
non_inline_handlers = [
handler for handler in self.handlers if not handler.run_inline
]
managers = []
for prompt in prompts:
@ -1739,20 +1744,36 @@ class AsyncCallbackManager(BaseCallbackManager):
else:
run_id_ = uuid.uuid4()
tasks.append(
ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
if inline_handlers:
inline_tasks.append(
ahandle_event(
inline_handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
)
else:
non_inline_tasks.append(
ahandle_event(
non_inline_handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
@ -1767,7 +1788,13 @@ class AsyncCallbackManager(BaseCallbackManager):
)
)
await asyncio.gather(*tasks)
# Run inline tasks sequentially
for inline_task in inline_tasks:
await inline_task
# Run non-inline tasks concurrently
if non_inline_tasks:
await asyncio.gather(*non_inline_tasks)
return managers
@ -1791,7 +1818,8 @@ class AsyncCallbackManager(BaseCallbackManager):
async callback managers, one for each LLM Run
corresponding to each inner message list.
"""
tasks = []
inline_tasks = []
non_inline_tasks = []
managers = []
for message_list in messages:
@ -1801,9 +1829,9 @@ class AsyncCallbackManager(BaseCallbackManager):
else:
run_id_ = uuid.uuid4()
tasks.append(
ahandle_event(
self.handlers,
for handler in self.handlers:
task = ahandle_event(
[handler],
"on_chat_model_start",
"ignore_chat_model",
serialized,
@ -1814,7 +1842,10 @@ class AsyncCallbackManager(BaseCallbackManager):
metadata=self.metadata,
**kwargs,
)
)
if handler.run_inline:
inline_tasks.append(task)
else:
non_inline_tasks.append(task)
managers.append(
AsyncCallbackManagerForLLMRun(
@ -1829,7 +1860,14 @@ class AsyncCallbackManager(BaseCallbackManager):
)
)
await asyncio.gather(*tasks)
# Run inline tasks sequentially
for task in inline_tasks:
await task
# Run non-inline tasks concurrently
if non_inline_tasks:
await asyncio.gather(*non_inline_tasks)
return managers
async def on_chain_start(

View File

@ -0,0 +1,148 @@
"""Unit tests for verifying event dispatching.
Much of this code is indirectly tested already through many end-to-end tests
that generate traces based on the callbacks. The traces are all verified
via snapshot testing (e.g., see unit tests for runnables).
"""
import contextvars
from contextlib import asynccontextmanager
from typing import Any, Optional
from uuid import UUID
from langchain_core.callbacks import (
AsyncCallbackHandler,
AsyncCallbackManager,
BaseCallbackHandler,
)
async def test_inline_handlers_share_parent_context() -> None:
    """Check that an inline handler's context-variable writes reach the caller.

    A handler whose ``run_inline`` attribute is true sets a context variable
    from ``on_llm_start``; once the manager call has been awaited, the caller
    is expected to observe that value. The same handler registered with
    ``run_inline=False`` is expected to leave the caller's value untouched.

    Background: ``on_llm_start`` and ``on_chat_model_start`` can fan out into
    several pieces of work (one per prompt / message list) that used to be
    awaited together through ``asyncio.gather``, which does not hand context
    changes back to the awaiting code.
    """
    tracked_var: contextvars.ContextVar[str] = contextvars.ContextVar("some_var")

    class ContextWritingHandler(AsyncCallbackHandler):
        """Handler that records the callback name in the context variable."""

        def __init__(self, run_inline: bool) -> None:
            """Store whether the manager should run this handler inline."""
            self.run_inline = run_inline

        async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
            """Overwrite the context variable with this callback's name."""
            tracked_var.set("on_llm_start")

    # Each case pairs a run_inline setting with the value the caller should
    # read back after dispatch: the inline handler's write is visible, the
    # non-inline handler's write is not.
    cases = (
        (True, "on_llm_start"),
        (False, "unset"),
    )
    for inline, expected_value in cases:
        # The manager dispatches the callback to every registered handler.
        dispatcher = AsyncCallbackManager(
            handlers=[ContextWritingHandler(run_inline=inline)]
        )
        tracked_var.set("unset")
        await dispatcher.on_llm_start({}, ["prompt 1"])
        assert tracked_var.get() == expected_value
async def test_inline_handlers_share_parent_context_multiple() -> None:
    """Exercise inline handlers with several prompts and several handlers.

    Two inline handlers share a counter stored in a context variable: the
    first increments it, the second only reads it. A third, non-inline
    handler records ``None``. For three prompts the non-``None`` states are
    expected to come out as ``1, 1, 2, 2, 3, 3``, i.e. each reader sees the
    increment made just before it and increments accumulate across prompts.
    """
    counter_var = contextvars.ContextVar("counter", default=0)
    # States recorded by the handlers, in the order the handlers ran.
    shared_stack: list[Optional[int]] = []

    @asynccontextmanager
    async def set_counter_var() -> Any:
        """Zero the counter for the duration of the block, then restore it."""
        reset_token = counter_var.set(0)
        try:
            yield
        finally:
            counter_var.reset(reset_token)

    class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
        """Handler whose behaviour is selected by its ``name``."""

        def __init__(self, name: str, run_inline: bool = True):
            self.name = name
            self.run_inline = run_inline

        async def on_llm_start(
            self,
            serialized: dict[str, Any],
            prompts: list[str],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> None:
            """Record the counter state this handler observes."""
            state: Optional[int] = None
            if self.name == "StateModifier":
                # Bump the shared counter, then record the bumped value.
                counter_var.set(counter_var.get() + 1)
                state = counter_var.get()
            elif self.name == "StateReader":
                state = counter_var.get()
            shared_stack.append(state)

            await super().on_llm_start(
                serialized,
                prompts,
                run_id=run_id,
                parent_run_id=parent_run_id,
                **kwargs,
            )

    prompts = ["Prompt1", "Prompt2", "Prompt3"]
    handlers: list[BaseCallbackHandler] = [
        StatefulAsyncCallbackHandler("StateModifier", run_inline=True),
        StatefulAsyncCallbackHandler("StateReader", run_inline=True),
        StatefulAsyncCallbackHandler("NonInlineHandler", run_inline=False),
    ]

    async with set_counter_var():
        shared_stack.clear()
        manager = AsyncCallbackManager(handlers=handlers)
        await manager.on_llm_start({}, prompts)

        # Drop the None entries recorded by the non-inline handler and check
        # the order in which the inline handlers observed the counter.
        states = [entry for entry in shared_stack if entry is not None]
        expected_states = [1, 1, 2, 2, 3, 3]
        assert (
            states == expected_states
        ), f"Expected order of states was broken due to context loss. Got {states}"