From 2c2daa5316a55d0b74cb1380a3820abe2b6c9545 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 22 Jul 2025 00:56:10 +0000 Subject: [PATCH 1/6] Initial plan From f365865fb5a91b419e697e2954d46eb6fb78b434 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 22 Jul 2025 01:12:47 +0000 Subject: [PATCH 2/6] Fix context preservation in shielded async callbacks Co-authored-by: mdrxy <61371264+mdrxy@users.noreply.github.com> --- libs/core/langchain_core/callbacks/manager.py | 18 +++++- .../callbacks/test_async_callback_manager.py | 63 ++++++++++++++++++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 3b4f1f6e30e..487fc44aab4 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -236,7 +236,23 @@ def shielded(func: Func) -> Func: @functools.wraps(func) async def wrapped(*args: Any, **kwargs: Any) -> Any: - return await asyncio.shield(func(*args, **kwargs)) + # Capture the current context to preserve context variables + ctx = copy_context() + + # Create the coroutine + coro = func(*args, **kwargs) + + # For Python 3.11+, create task with explicit context + # For older versions, fallback to original behavior + try: + # Create a task with the captured context to preserve context variables + task = asyncio.create_task(coro, context=ctx) + return await asyncio.shield(task) + except TypeError: + # Python < 3.11 fallback - create task normally then shield + # This won't preserve context perfectly but is better than nothing + task = asyncio.create_task(coro) + return await asyncio.shield(task) return cast("Func", wrapped) diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 5ae0d316e6c..00472f030a7 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -148,4 +148,65 @@ async def test_inline_handlers_share_parent_context_multiple() -> None: 2, 3, 3, - ], f"Expected order of states was broken due to context loss. Got {states}" + ] + + +async def test_shielded_callback_context_preservation() -> None: + """Verify that shielded callbacks preserve context variables. + + This test specifically addresses the issue where async callbacks decorated + with @shielded do not properly preserve context variables, breaking + instrumentation and other context-dependent functionality. + + The issue manifests in callbacks that use the @shielded decorator: + * on_llm_end + * on_llm_error + * on_chain_end + * on_chain_error + * And other shielded callback methods + """ + context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context") + + class ContextTestHandler(AsyncCallbackHandler): + """Handler that reads context variables in shielded callbacks.""" + + def __init__(self) -> None: + self.run_inline = False + self.context_values = [] + + @override + async def on_llm_end(self, response: Any, **kwargs: Any) -> None: + """This method is decorated with @shielded in the run manager.""" + # This should preserve the context variable value + self.context_values.append(context_var.get("not_found")) + + @override + async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None: + """This method is decorated with @shielded in the run manager.""" + # This should preserve the context variable value + self.context_values.append(context_var.get("not_found")) + + # Set up the test context + context_var.set("test_value") + handler = ContextTestHandler() + manager = AsyncCallbackManager(handlers=[handler]) + + # Create run managers that have the shielded methods + llm_managers = await manager.on_llm_start({}, ["test prompt"]) + llm_run_manager = llm_managers[0] + + chain_run_manager = await manager.on_chain_start({}, {"test": "input"}) + + # Test LLM end callback (which is shielded) + await llm_run_manager.on_llm_end({"response": "test"}) + + # Test Chain end callback (which is shielded) + await chain_run_manager.on_chain_end({"output": "test"}) + + # The context should be preserved in shielded callbacks + # This was the main issue - shielded decorators were not preserving context + assert handler.context_values == ["test_value", "test_value"], ( + f"Expected context values ['test_value', 'test_value'], " + f"but got {handler.context_values}. " + f"This indicates the shielded decorator is not preserving context variables." + ) From cd14cc5014afa12bd4bf4d5a741f107d9720dcfc Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 21 Jul 2025 21:17:24 -0400 Subject: [PATCH 3/6] format --- libs/core/langchain_core/callbacks/manager.py | 4 +-- .../callbacks/test_async_callback_manager.py | 30 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 487fc44aab4..680e35e811e 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -238,10 +238,10 @@ def shielded(func: Func) -> Func: async def wrapped(*args: Any, **kwargs: Any) -> Any: # Capture the current context to preserve context variables ctx = copy_context() - + # Create the coroutine coro = func(*args, **kwargs) - + # For Python 3.11+, create task with explicit context # For older versions, fallback to original behavior try: diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 00472f030a7..b453d755b84 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -153,56 +153,56 @@ async def test_inline_handlers_share_parent_context_multiple() -> None: async def test_shielded_callback_context_preservation() -> None: """Verify that shielded callbacks preserve context variables. - + This test specifically addresses the issue where async callbacks decorated with @shielded do not properly preserve context variables, breaking instrumentation and other context-dependent functionality. - + The issue manifests in callbacks that use the @shielded decorator: * on_llm_end - * on_llm_error + * on_llm_error * on_chain_end * on_chain_error * And other shielded callback methods """ context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context") - + class ContextTestHandler(AsyncCallbackHandler): """Handler that reads context variables in shielded callbacks.""" - + def __init__(self) -> None: self.run_inline = False self.context_values = [] - - @override + + @override async def on_llm_end(self, response: Any, **kwargs: Any) -> None: """This method is decorated with @shielded in the run manager.""" # This should preserve the context variable value self.context_values.append(context_var.get("not_found")) - + @override async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None: """This method is decorated with @shielded in the run manager.""" # This should preserve the context variable value self.context_values.append(context_var.get("not_found")) - + # Set up the test context context_var.set("test_value") handler = ContextTestHandler() manager = AsyncCallbackManager(handlers=[handler]) - + # Create run managers that have the shielded methods llm_managers = await manager.on_llm_start({}, ["test prompt"]) llm_run_manager = llm_managers[0] - + chain_run_manager = await manager.on_chain_start({}, {"test": "input"}) - + # Test LLM end callback (which is shielded) await llm_run_manager.on_llm_end({"response": "test"}) - - # Test Chain end callback (which is shielded) + + # Test Chain end callback (which is shielded) await chain_run_manager.on_chain_end({"output": "test"}) - + # The context should be preserved in shielded callbacks # This was the main issue - shielded decorators were not preserving context assert handler.context_values == ["test_value", "test_value"], ( From 0e0b1f39ca898902180f05504844ea01c4b55fe4 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 21 Jul 2025 21:18:59 -0400 Subject: [PATCH 4/6] lint --- .../tests/unit_tests/callbacks/test_async_callback_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index b453d755b84..65b60f13faa 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -172,7 +172,7 @@ async def test_shielded_callback_context_preservation() -> None: def __init__(self) -> None: self.run_inline = False - self.context_values = [] + self.context_values: list[str] = [] @override async def on_llm_end(self, response: Any, **kwargs: Any) -> None: @@ -198,7 +198,7 @@ async def test_shielded_callback_context_preservation() -> None: chain_run_manager = await manager.on_chain_start({}, {"test": "input"}) # Test LLM end callback (which is shielded) - await llm_run_manager.on_llm_end({"response": "test"}) + await llm_run_manager.on_llm_end({"response": "test"}) # type: ignore[arg-type] # Test Chain end callback (which is shielded) await chain_run_manager.on_chain_end({"output": "test"}) From 0a07cde3a21420a375fab1d6d1b43d19af9d4438 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Mon, 21 Jul 2025 21:22:16 -0400 Subject: [PATCH 5/6] fix: add type ignore for asyncio.create_task to support Python 3.9 and 3.10 --- libs/core/langchain_core/callbacks/manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 680e35e811e..88aebbf3164 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -246,7 +246,8 @@ def shielded(func: Func) -> Func: # For older versions, fallback to original behavior try: # Create a task with the captured context to preserve context variables - task = asyncio.create_task(coro, context=ctx) + task = asyncio.create_task(coro, context=ctx) # type: ignore[call-arg] + # `call-arg` used to not fail 3.9 or 3.10 tests return await asyncio.shield(task) except TypeError: # Python < 3.11 fallback - create task normally then shield From a5cecf77f0925936f6524af8cc33f6c983b26c9e Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Tue, 22 Jul 2025 10:33:14 -0400 Subject: [PATCH 6/6] unused-ignore --- libs/core/langchain_core/callbacks/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 88aebbf3164..4f4a92b1711 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -246,7 +246,7 @@ def shielded(func: Func) -> Func: # For older versions, fallback to original behavior try: # Create a task with the captured context to preserve context variables - task = asyncio.create_task(coro, context=ctx) # type: ignore[call-arg] + task = asyncio.create_task(coro, context=ctx) # type: ignore[call-arg, unused-ignore] # `call-arg` used to not fail 3.9 or 3.10 tests return await asyncio.shield(task) except TypeError: