mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 12:31:49 +00:00
core[patch]: Update AsyncCallbackManager
to honor run_inline
attribute and prevent context loss (#26885)
## Description This PR fixes the context loss issue in `AsyncCallbackManager`, specifically in `on_llm_start` and `on_chat_model_start` methods. It properly honors the `run_inline` attribute of callback handlers, preventing race conditions and ordering issues. Key changes: 1. Separate handlers into inline and non-inline groups. 2. Execute inline handlers sequentially for each prompt. 3. Execute non-inline handlers concurrently across all prompts. 4. Preserve context for stateful handlers. 5. Maintain performance benefits for non-inline handlers. **These changes are implemented in `AsyncCallbackManager` rather than `ahandle_event` because the issue occurs at the prompt and message_list levels, not within individual events.** ## Testing - Test case implemented in #26857 now passes, verifying execution order for inline handlers. ## Related Issues - Fixes issue discussed in #23909 ## Dependencies No new dependencies are required. --- @eyurtsev: This PR implements the discussed changes to respect `run_inline` in `AsyncCallbackManager`. Please review and advise on any needed changes. Twitter handle: @parambharat --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
c61b9daef5
commit
931ce8d026
@ -1729,7 +1729,12 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
to each prompt.
|
to each prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tasks = []
|
inline_tasks = []
|
||||||
|
non_inline_tasks = []
|
||||||
|
inline_handlers = [handler for handler in self.handlers if handler.run_inline]
|
||||||
|
non_inline_handlers = [
|
||||||
|
handler for handler in self.handlers if not handler.run_inline
|
||||||
|
]
|
||||||
managers = []
|
managers = []
|
||||||
|
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
@ -1739,20 +1744,36 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
else:
|
else:
|
||||||
run_id_ = uuid.uuid4()
|
run_id_ = uuid.uuid4()
|
||||||
|
|
||||||
tasks.append(
|
if inline_handlers:
|
||||||
ahandle_event(
|
inline_tasks.append(
|
||||||
self.handlers,
|
ahandle_event(
|
||||||
"on_llm_start",
|
inline_handlers,
|
||||||
"ignore_llm",
|
"on_llm_start",
|
||||||
serialized,
|
"ignore_llm",
|
||||||
[prompt],
|
serialized,
|
||||||
run_id=run_id_,
|
[prompt],
|
||||||
parent_run_id=self.parent_run_id,
|
run_id=run_id_,
|
||||||
tags=self.tags,
|
parent_run_id=self.parent_run_id,
|
||||||
metadata=self.metadata,
|
tags=self.tags,
|
||||||
**kwargs,
|
metadata=self.metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
non_inline_tasks.append(
|
||||||
|
ahandle_event(
|
||||||
|
non_inline_handlers,
|
||||||
|
"on_llm_start",
|
||||||
|
"ignore_llm",
|
||||||
|
serialized,
|
||||||
|
[prompt],
|
||||||
|
run_id=run_id_,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
metadata=self.metadata,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
managers.append(
|
managers.append(
|
||||||
AsyncCallbackManagerForLLMRun(
|
AsyncCallbackManagerForLLMRun(
|
||||||
@ -1767,7 +1788,13 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
# Run inline tasks sequentially
|
||||||
|
for inline_task in inline_tasks:
|
||||||
|
await inline_task
|
||||||
|
|
||||||
|
# Run non-inline tasks concurrently
|
||||||
|
if non_inline_tasks:
|
||||||
|
await asyncio.gather(*non_inline_tasks)
|
||||||
|
|
||||||
return managers
|
return managers
|
||||||
|
|
||||||
@ -1791,7 +1818,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
async callback managers, one for each LLM Run
|
async callback managers, one for each LLM Run
|
||||||
corresponding to each inner message list.
|
corresponding to each inner message list.
|
||||||
"""
|
"""
|
||||||
tasks = []
|
inline_tasks = []
|
||||||
|
non_inline_tasks = []
|
||||||
managers = []
|
managers = []
|
||||||
|
|
||||||
for message_list in messages:
|
for message_list in messages:
|
||||||
@ -1801,9 +1829,9 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
else:
|
else:
|
||||||
run_id_ = uuid.uuid4()
|
run_id_ = uuid.uuid4()
|
||||||
|
|
||||||
tasks.append(
|
for handler in self.handlers:
|
||||||
ahandle_event(
|
task = ahandle_event(
|
||||||
self.handlers,
|
[handler],
|
||||||
"on_chat_model_start",
|
"on_chat_model_start",
|
||||||
"ignore_chat_model",
|
"ignore_chat_model",
|
||||||
serialized,
|
serialized,
|
||||||
@ -1814,7 +1842,10 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
metadata=self.metadata,
|
metadata=self.metadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
if handler.run_inline:
|
||||||
|
inline_tasks.append(task)
|
||||||
|
else:
|
||||||
|
non_inline_tasks.append(task)
|
||||||
|
|
||||||
managers.append(
|
managers.append(
|
||||||
AsyncCallbackManagerForLLMRun(
|
AsyncCallbackManagerForLLMRun(
|
||||||
@ -1829,7 +1860,14 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
# Run inline tasks sequentially
|
||||||
|
for task in inline_tasks:
|
||||||
|
await task
|
||||||
|
|
||||||
|
# Run non-inline tasks concurrently
|
||||||
|
if non_inline_tasks:
|
||||||
|
await asyncio.gather(*non_inline_tasks)
|
||||||
|
|
||||||
return managers
|
return managers
|
||||||
|
|
||||||
async def on_chain_start(
|
async def on_chain_start(
|
||||||
|
@ -0,0 +1,148 @@
|
|||||||
|
"""Unit tests for verifying event dispatching.
|
||||||
|
|
||||||
|
Much of this code is indirectly tested already through many end-to-end tests
|
||||||
|
that generate traces based on the callbacks. The traces are all verified
|
||||||
|
via snapshot testing (e.g., see unit tests for runnables).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextvars
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackHandler,
|
||||||
|
AsyncCallbackManager,
|
||||||
|
BaseCallbackHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_inline_handlers_share_parent_context() -> None:
|
||||||
|
"""Verify that handlers that are configured to run_inline can update parent context.
|
||||||
|
|
||||||
|
This test was created because some of the inline handlers were getting
|
||||||
|
their own context as the handling logic was kicked off using asyncio.gather
|
||||||
|
which does not automatically propagate the parent context (by design).
|
||||||
|
|
||||||
|
This issue was affecting only a few specific handlers:
|
||||||
|
|
||||||
|
* on_llm_start
|
||||||
|
* on_chat_model_start
|
||||||
|
|
||||||
|
which in some cases were triggered with multiple prompts and as a result
|
||||||
|
triggering multiple tasks that were launched in parallel.
|
||||||
|
"""
|
||||||
|
some_var: contextvars.ContextVar[str] = contextvars.ContextVar("some_var")
|
||||||
|
|
||||||
|
class CustomHandler(AsyncCallbackHandler):
|
||||||
|
"""A handler that sets the context variable.
|
||||||
|
|
||||||
|
The handler sets the context variable to the name of the callback that was
|
||||||
|
called.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, run_inline: bool) -> None:
|
||||||
|
"""Initialize the handler."""
|
||||||
|
self.run_inline = run_inline
|
||||||
|
|
||||||
|
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Update the callstack with the name of the callback."""
|
||||||
|
some_var.set("on_llm_start")
|
||||||
|
|
||||||
|
# The manager serves as a callback dispatcher.
|
||||||
|
# It's responsible for dispatching callbacks to all registered handlers.
|
||||||
|
manager = AsyncCallbackManager(handlers=[CustomHandler(run_inline=True)])
|
||||||
|
|
||||||
|
# Check on_llm_start
|
||||||
|
some_var.set("unset")
|
||||||
|
await manager.on_llm_start({}, ["prompt 1"])
|
||||||
|
assert some_var.get() == "on_llm_start"
|
||||||
|
|
||||||
|
# Check what happens when run_inline is False
|
||||||
|
# We don't expect the context to be updated
|
||||||
|
manager2 = AsyncCallbackManager(
|
||||||
|
handlers=[
|
||||||
|
CustomHandler(run_inline=False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
some_var.set("unset")
|
||||||
|
await manager2.on_llm_start({}, ["prompt 1"])
|
||||||
|
# Will not be updated because the handler is not inline
|
||||||
|
assert some_var.get() == "unset"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_inline_handlers_share_parent_context_multiple() -> None:
|
||||||
|
"""A slightly more complex variation of the test unit test above.
|
||||||
|
|
||||||
|
This unit test verifies that things work correctly when there are multiple prompts,
|
||||||
|
and multiple handlers that are configured to run inline.
|
||||||
|
"""
|
||||||
|
counter_var = contextvars.ContextVar("counter", default=0)
|
||||||
|
|
||||||
|
shared_stack = []
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def set_counter_var() -> Any:
|
||||||
|
token = counter_var.set(0)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
counter_var.reset(token)
|
||||||
|
|
||||||
|
class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
|
||||||
|
def __init__(self, name: str, run_inline: bool = True):
|
||||||
|
self.name = name
|
||||||
|
self.run_inline = run_inline
|
||||||
|
|
||||||
|
async def on_llm_start(
|
||||||
|
self,
|
||||||
|
serialized: dict[str, Any],
|
||||||
|
prompts: list[str],
|
||||||
|
*,
|
||||||
|
run_id: UUID,
|
||||||
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
if self.name == "StateModifier":
|
||||||
|
current_counter = counter_var.get()
|
||||||
|
counter_var.set(current_counter + 1)
|
||||||
|
state = counter_var.get()
|
||||||
|
elif self.name == "StateReader":
|
||||||
|
state = counter_var.get()
|
||||||
|
else:
|
||||||
|
state = None
|
||||||
|
|
||||||
|
shared_stack.append(state)
|
||||||
|
|
||||||
|
await super().on_llm_start(
|
||||||
|
serialized,
|
||||||
|
prompts,
|
||||||
|
run_id=run_id,
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
handlers: list[BaseCallbackHandler] = [
|
||||||
|
StatefulAsyncCallbackHandler("StateModifier", run_inline=True),
|
||||||
|
StatefulAsyncCallbackHandler("StateReader", run_inline=True),
|
||||||
|
StatefulAsyncCallbackHandler("NonInlineHandler", run_inline=False),
|
||||||
|
]
|
||||||
|
|
||||||
|
prompts = ["Prompt1", "Prompt2", "Prompt3"]
|
||||||
|
|
||||||
|
async with set_counter_var():
|
||||||
|
shared_stack.clear()
|
||||||
|
manager = AsyncCallbackManager(handlers=handlers)
|
||||||
|
await manager.on_llm_start({}, prompts)
|
||||||
|
|
||||||
|
# Assert the order of states
|
||||||
|
states = [entry for entry in shared_stack if entry is not None]
|
||||||
|
assert states == [
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
], f"Expected order of states was broken due to context loss. Got {states}"
|
Loading…
Reference in New Issue
Block a user