This commit is contained in:
Mason Daugherty 2025-07-21 21:17:24 -04:00
parent f365865fb5
commit cd14cc5014
No known key found for this signature in database
2 changed files with 17 additions and 17 deletions

View File

@ -238,10 +238,10 @@ def shielded(func: Func) -> Func:
async def wrapped(*args: Any, **kwargs: Any) -> Any: async def wrapped(*args: Any, **kwargs: Any) -> Any:
# Capture the current context to preserve context variables # Capture the current context to preserve context variables
ctx = copy_context() ctx = copy_context()
# Create the coroutine # Create the coroutine
coro = func(*args, **kwargs) coro = func(*args, **kwargs)
# For Python 3.11+, create task with explicit context # For Python 3.11+, create task with explicit context
# For older versions, fallback to original behavior # For older versions, fallback to original behavior
try: try:

View File

@ -153,56 +153,56 @@ async def test_inline_handlers_share_parent_context_multiple() -> None:
async def test_shielded_callback_context_preservation() -> None: async def test_shielded_callback_context_preservation() -> None:
"""Verify that shielded callbacks preserve context variables. """Verify that shielded callbacks preserve context variables.
This test specifically addresses the issue where async callbacks decorated This test specifically addresses the issue where async callbacks decorated
with @shielded do not properly preserve context variables, breaking with @shielded do not properly preserve context variables, breaking
instrumentation and other context-dependent functionality. instrumentation and other context-dependent functionality.
The issue manifests in callbacks that use the @shielded decorator: The issue manifests in callbacks that use the @shielded decorator:
* on_llm_end * on_llm_end
* on_llm_error * on_llm_error
* on_chain_end * on_chain_end
* on_chain_error * on_chain_error
* And other shielded callback methods * And other shielded callback methods
""" """
context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context") context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")
class ContextTestHandler(AsyncCallbackHandler): class ContextTestHandler(AsyncCallbackHandler):
"""Handler that reads context variables in shielded callbacks.""" """Handler that reads context variables in shielded callbacks."""
def __init__(self) -> None: def __init__(self) -> None:
self.run_inline = False self.run_inline = False
self.context_values = [] self.context_values = []
@override @override
async def on_llm_end(self, response: Any, **kwargs: Any) -> None: async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
"""This method is decorated with @shielded in the run manager.""" """This method is decorated with @shielded in the run manager."""
# This should preserve the context variable value # This should preserve the context variable value
self.context_values.append(context_var.get("not_found")) self.context_values.append(context_var.get("not_found"))
@override @override
async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None: async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
"""This method is decorated with @shielded in the run manager.""" """This method is decorated with @shielded in the run manager."""
# This should preserve the context variable value # This should preserve the context variable value
self.context_values.append(context_var.get("not_found")) self.context_values.append(context_var.get("not_found"))
# Set up the test context # Set up the test context
context_var.set("test_value") context_var.set("test_value")
handler = ContextTestHandler() handler = ContextTestHandler()
manager = AsyncCallbackManager(handlers=[handler]) manager = AsyncCallbackManager(handlers=[handler])
# Create run managers that have the shielded methods # Create run managers that have the shielded methods
llm_managers = await manager.on_llm_start({}, ["test prompt"]) llm_managers = await manager.on_llm_start({}, ["test prompt"])
llm_run_manager = llm_managers[0] llm_run_manager = llm_managers[0]
chain_run_manager = await manager.on_chain_start({}, {"test": "input"}) chain_run_manager = await manager.on_chain_start({}, {"test": "input"})
# Test LLM end callback (which is shielded) # Test LLM end callback (which is shielded)
await llm_run_manager.on_llm_end({"response": "test"}) await llm_run_manager.on_llm_end({"response": "test"})
# Test Chain end callback (which is shielded) # Test Chain end callback (which is shielded)
await chain_run_manager.on_chain_end({"output": "test"}) await chain_run_manager.on_chain_end({"output": "test"})
# The context should be preserved in shielded callbacks # The context should be preserved in shielded callbacks
# This was the main issue - shielded decorators were not preserving context # This was the main issue - shielded decorators were not preserving context
assert handler.context_values == ["test_value", "test_value"], ( assert handler.context_values == ["test_value", "test_value"], (