diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 56fc1bb67ba..2a74e42d9b2 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -238,7 +238,24 @@ def shielded(func: Func) -> Func: @functools.wraps(func) async def wrapped(*args: Any, **kwargs: Any) -> Any: - return await asyncio.shield(func(*args, **kwargs)) + # Capture the current context to preserve context variables + ctx = copy_context() + + # Create the coroutine + coro = func(*args, **kwargs) + + # For Python 3.11+, create task with explicit context + # For older versions, fallback to original behavior + try: + # Create a task with the captured context to preserve context variables + task = asyncio.create_task(coro, context=ctx) # type: ignore[call-arg, unused-ignore] + # `call-arg` used to not fail 3.9 or 3.10 tests + return await asyncio.shield(task) + except TypeError: + # Python < 3.11 fallback - create task normally then shield + # This won't preserve context perfectly but is better than nothing + task = asyncio.create_task(coro) + return await asyncio.shield(task) return cast("Func", wrapped) diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 5ae0d316e6c..65b60f13faa 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -148,4 +148,65 @@ async def test_inline_handlers_share_parent_context_multiple() -> None: 2, 3, 3, - ], f"Expected order of states was broken due to context loss. Got {states}" + ] + + +async def test_shielded_callback_context_preservation() -> None: + """Verify that shielded callbacks preserve context variables. + + This test specifically addresses the issue where async callbacks decorated + with @shielded do not properly preserve context variables, breaking + instrumentation and other context-dependent functionality. + + The issue manifests in callbacks that use the @shielded decorator: + * on_llm_end + * on_llm_error + * on_chain_end + * on_chain_error + * And other shielded callback methods + """ + context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context") + + class ContextTestHandler(AsyncCallbackHandler): + """Handler that reads context variables in shielded callbacks.""" + + def __init__(self) -> None: + self.run_inline = False + self.context_values: list[str] = [] + + @override + async def on_llm_end(self, response: Any, **kwargs: Any) -> None: + """This method is decorated with @shielded in the run manager.""" + # This should preserve the context variable value + self.context_values.append(context_var.get("not_found")) + + @override + async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None: + """This method is decorated with @shielded in the run manager.""" + # This should preserve the context variable value + self.context_values.append(context_var.get("not_found")) + + # Set up the test context + context_var.set("test_value") + handler = ContextTestHandler() + manager = AsyncCallbackManager(handlers=[handler]) + + # Create run managers that have the shielded methods + llm_managers = await manager.on_llm_start({}, ["test prompt"]) + llm_run_manager = llm_managers[0] + + chain_run_manager = await manager.on_chain_start({}, {"test": "input"}) + + # Test LLM end callback (which is shielded) + await llm_run_manager.on_llm_end({"response": "test"}) # type: ignore[arg-type] + + # Test Chain end callback (which is shielded) + await chain_run_manager.on_chain_end({"output": "test"}) + + # The context should be preserved in shielded callbacks + # This was the main issue - shielded decorators were not preserving context + assert handler.context_values == ["test_value", "test_value"], ( + f"Expected context values ['test_value', 'test_value'], " + f"but got {handler.context_values}. " + f"This indicates the shielded decorator is not preserving context variables." + )