diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index 05dd786591d..604e17c10af 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -1729,7 +1729,12 @@ class AsyncCallbackManager(BaseCallbackManager):
                 to each prompt.
         """
-        tasks = []
+        inline_tasks = []
+        non_inline_tasks = []
+        inline_handlers = [handler for handler in self.handlers if handler.run_inline]
+        non_inline_handlers = [
+            handler for handler in self.handlers if not handler.run_inline
+        ]
         managers = []

         for prompt in prompts:
@@ -1739,20 +1744,36 @@ class AsyncCallbackManager(BaseCallbackManager):
             else:
                 run_id_ = uuid.uuid4()

-            tasks.append(
-                ahandle_event(
-                    self.handlers,
-                    "on_llm_start",
-                    "ignore_llm",
-                    serialized,
-                    [prompt],
-                    run_id=run_id_,
-                    parent_run_id=self.parent_run_id,
-                    tags=self.tags,
-                    metadata=self.metadata,
-                    **kwargs,
+            if inline_handlers:
+                inline_tasks.append(
+                    ahandle_event(
+                        inline_handlers,
+                        "on_llm_start",
+                        "ignore_llm",
+                        serialized,
+                        [prompt],
+                        run_id=run_id_,
+                        parent_run_id=self.parent_run_id,
+                        tags=self.tags,
+                        metadata=self.metadata,
+                        **kwargs,
+                    )
+                )
+            else:
+                non_inline_tasks.append(
+                    ahandle_event(
+                        non_inline_handlers,
+                        "on_llm_start",
+                        "ignore_llm",
+                        serialized,
+                        [prompt],
+                        run_id=run_id_,
+                        parent_run_id=self.parent_run_id,
+                        tags=self.tags,
+                        metadata=self.metadata,
+                        **kwargs,
+                    )
                 )
-            )

             managers.append(
                 AsyncCallbackManagerForLLMRun(
@@ -1767,7 +1788,13 @@ class AsyncCallbackManager(BaseCallbackManager):
                 )
             )

-        await asyncio.gather(*tasks)
+        # Run inline tasks sequentially
+        for inline_task in inline_tasks:
+            await inline_task
+
+        # Run non-inline tasks concurrently
+        if non_inline_tasks:
+            await asyncio.gather(*non_inline_tasks)

         return managers

@@ -1791,7 +1818,8 @@ class AsyncCallbackManager(BaseCallbackManager):
                 async callback managers, one for each LLM Run corresponding
                 to each inner message list.
         """
-        tasks = []
+        inline_tasks = []
+        non_inline_tasks = []
         managers = []

         for message_list in messages:
@@ -1801,9 +1829,9 @@ class AsyncCallbackManager(BaseCallbackManager):
             else:
                 run_id_ = uuid.uuid4()

-            tasks.append(
-                ahandle_event(
-                    self.handlers,
+            for handler in self.handlers:
+                task = ahandle_event(
+                    [handler],
                     "on_chat_model_start",
                     "ignore_chat_model",
                     serialized,
@@ -1814,7 +1842,10 @@ class AsyncCallbackManager(BaseCallbackManager):
                     metadata=self.metadata,
                     **kwargs,
                 )
-            )
+                if handler.run_inline:
+                    inline_tasks.append(task)
+                else:
+                    non_inline_tasks.append(task)

             managers.append(
                 AsyncCallbackManagerForLLMRun(
@@ -1829,7 +1860,14 @@ class AsyncCallbackManager(BaseCallbackManager):
                 )
             )

-        await asyncio.gather(*tasks)
+        # Run inline tasks sequentially
+        for task in inline_tasks:
+            await task
+
+        # Run non-inline tasks concurrently
+        if non_inline_tasks:
+            await asyncio.gather(*non_inline_tasks)
+
         return managers

     async def on_chain_start(
diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py
new file mode 100644
index 00000000000..38350f9d82f
--- /dev/null
+++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py
@@ -0,0 +1,148 @@
+"""Unit tests for verifying event dispatching.
+
+Much of this code is indirectly tested already through many end-to-end tests
+that generate traces based on the callbacks. The traces are all verified
+via snapshot testing (e.g., see unit tests for runnables).
+""" + +import contextvars +from contextlib import asynccontextmanager +from typing import Any, Optional +from uuid import UUID + +from langchain_core.callbacks import ( + AsyncCallbackHandler, + AsyncCallbackManager, + BaseCallbackHandler, +) + + +async def test_inline_handlers_share_parent_context() -> None: + """Verify that handlers that are configured to run_inline can update parent context. + + This test was created because some of the inline handlers were getting + their own context as the handling logic was kicked off using asyncio.gather + which does not automatically propagate the parent context (by design). + + This issue was affecting only a few specific handlers: + + * on_llm_start + * on_chat_model_start + + which in some cases were triggered with multiple prompts and as a result + triggering multiple tasks that were launched in parallel. + """ + some_var: contextvars.ContextVar[str] = contextvars.ContextVar("some_var") + + class CustomHandler(AsyncCallbackHandler): + """A handler that sets the context variable. + + The handler sets the context variable to the name of the callback that was + called. + """ + + def __init__(self, run_inline: bool) -> None: + """Initialize the handler.""" + self.run_inline = run_inline + + async def on_llm_start(self, *args: Any, **kwargs: Any) -> None: + """Update the callstack with the name of the callback.""" + some_var.set("on_llm_start") + + # The manager serves as a callback dispatcher. + # It's responsible for dispatching callbacks to all registered handlers. + manager = AsyncCallbackManager(handlers=[CustomHandler(run_inline=True)]) + + # Check on_llm_start + some_var.set("unset") + await manager.on_llm_start({}, ["prompt 1"]) + assert some_var.get() == "on_llm_start" + + # Check what happens when run_inline is False + # We don't expect the context to be updated + manager2 = AsyncCallbackManager( + handlers=[ + CustomHandler(run_inline=False), + ] + ) + + some_var.set("unset") + await manager2.on_llm_start({}, ["prompt 1"]) + # Will not be updated because the handler is not inline + assert some_var.get() == "unset" + + +async def test_inline_handlers_share_parent_context_multiple() -> None: + """A slightly more complex variation of the test unit test above. + + This unit test verifies that things work correctly when there are multiple prompts, + and multiple handlers that are configured to run inline. 
+    """
+    counter_var = contextvars.ContextVar("counter", default=0)
+
+    shared_stack = []
+
+    @asynccontextmanager
+    async def set_counter_var() -> Any:
+        token = counter_var.set(0)
+        try:
+            yield
+        finally:
+            counter_var.reset(token)
+
+    class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
+        def __init__(self, name: str, run_inline: bool = True):
+            self.name = name
+            self.run_inline = run_inline
+
+        async def on_llm_start(
+            self,
+            serialized: dict[str, Any],
+            prompts: list[str],
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            **kwargs: Any,
+        ) -> None:
+            if self.name == "StateModifier":
+                current_counter = counter_var.get()
+                counter_var.set(current_counter + 1)
+                state = counter_var.get()
+            elif self.name == "StateReader":
+                state = counter_var.get()
+            else:
+                state = None
+
+            shared_stack.append(state)
+
+            await super().on_llm_start(
+                serialized,
+                prompts,
+                run_id=run_id,
+                parent_run_id=parent_run_id,
+                **kwargs,
+            )
+
+    handlers: list[BaseCallbackHandler] = [
+        StatefulAsyncCallbackHandler("StateModifier", run_inline=True),
+        StatefulAsyncCallbackHandler("StateReader", run_inline=True),
+        StatefulAsyncCallbackHandler("NonInlineHandler", run_inline=False),
+    ]
+
+    prompts = ["Prompt1", "Prompt2", "Prompt3"]
+
+    async with set_counter_var():
+        shared_stack.clear()
+        manager = AsyncCallbackManager(handlers=handlers)
+        await manager.on_llm_start({}, prompts)
+
+        # Assert the order of states
+        states = [entry for entry in shared_stack if entry is not None]
+        assert states == [
+            1,
+            1,
+            2,
+            2,
+            3,
+            3,
+        ], f"Expected order of states was broken due to context loss. Got {states}"
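
The behaviour these tests pin down follows from how contextvars interact with asyncio: a coroutine that is awaited directly runs in the caller's context, while a coroutine handed to asyncio.gather is wrapped in a task that runs in a copy of that context, so ContextVar writes made inside it never reach the caller. That is why the manager now awaits run_inline handlers one by one and only gathers the rest. The standalone sketch below is illustrative only (it is not part of the diff, and the names some_var, fake_inline_handler, and main are made up); it reproduces the difference with nothing but the standard library.

import asyncio
import contextvars

some_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "some_var", default="unset"
)


async def fake_inline_handler() -> None:
    # Stand-in for a callback that mutates context-local state.
    some_var.set("handled")


async def main() -> None:
    # Awaited directly (what the patch does for run_inline handlers): the write
    # happens in the caller's context and is visible afterwards.
    await fake_inline_handler()
    assert some_var.get() == "handled"

    some_var.set("unset")
    # Scheduled through asyncio.gather (how non-inline handlers are dispatched):
    # the coroutine runs as a task in a copied context, so the write does not
    # propagate back to the caller.
    await asyncio.gather(fake_inline_handler())
    assert some_var.get() == "unset"


asyncio.run(main())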