diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 487fc44aab4..680e35e811e 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -238,10 +238,10 @@ def shielded(func: Func) -> Func: async def wrapped(*args: Any, **kwargs: Any) -> Any: # Capture the current context to preserve context variables ctx = copy_context() - + # Create the coroutine coro = func(*args, **kwargs) - + # For Python 3.11+, create task with explicit context # For older versions, fallback to original behavior try: diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 00472f030a7..b453d755b84 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -153,56 +153,56 @@ async def test_inline_handlers_share_parent_context_multiple() -> None: async def test_shielded_callback_context_preservation() -> None: """Verify that shielded callbacks preserve context variables. - + This test specifically addresses the issue where async callbacks decorated with @shielded do not properly preserve context variables, breaking instrumentation and other context-dependent functionality. - + The issue manifests in callbacks that use the @shielded decorator: * on_llm_end - * on_llm_error + * on_llm_error * on_chain_end * on_chain_error * And other shielded callback methods """ context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context") - + class ContextTestHandler(AsyncCallbackHandler): """Handler that reads context variables in shielded callbacks.""" - + def __init__(self) -> None: self.run_inline = False self.context_values = [] - - @override + + @override async def on_llm_end(self, response: Any, **kwargs: Any) -> None: """This method is decorated with @shielded in the run manager.""" # This should preserve the context variable value self.context_values.append(context_var.get("not_found")) - + @override async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None: """This method is decorated with @shielded in the run manager.""" # This should preserve the context variable value self.context_values.append(context_var.get("not_found")) - + # Set up the test context context_var.set("test_value") handler = ContextTestHandler() manager = AsyncCallbackManager(handlers=[handler]) - + # Create run managers that have the shielded methods llm_managers = await manager.on_llm_start({}, ["test prompt"]) llm_run_manager = llm_managers[0] - + chain_run_manager = await manager.on_chain_start({}, {"test": "input"}) - + # Test LLM end callback (which is shielded) await llm_run_manager.on_llm_end({"response": "test"}) - - # Test Chain end callback (which is shielded) + + # Test Chain end callback (which is shielded) await chain_run_manager.on_chain_end({"output": "test"}) - + # The context should be preserved in shielded callbacks # This was the main issue - shielded decorators were not preserving context assert handler.context_values == ["test_value", "test_value"], (