mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-30 00:04:19 +00:00
Merge a5cecf77f0
into 0e287763cd
This commit is contained in:
commit
4b8d6cb649
@ -238,7 +238,24 @@ def shielded(func: Func) -> Func:
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
return await asyncio.shield(func(*args, **kwargs))
|
||||
# Capture the current context to preserve context variables
|
||||
ctx = copy_context()
|
||||
|
||||
# Create the coroutine
|
||||
coro = func(*args, **kwargs)
|
||||
|
||||
# For Python 3.11+, create task with explicit context
|
||||
# For older versions, fallback to original behavior
|
||||
try:
|
||||
# Create a task with the captured context to preserve context variables
|
||||
task = asyncio.create_task(coro, context=ctx) # type: ignore[call-arg, unused-ignore]
|
||||
# `call-arg` used to not fail 3.9 or 3.10 tests
|
||||
return await asyncio.shield(task)
|
||||
except TypeError:
|
||||
# Python < 3.11 fallback - create task normally then shield
|
||||
# This won't preserve context perfectly but is better than nothing
|
||||
task = asyncio.create_task(coro)
|
||||
return await asyncio.shield(task)
|
||||
|
||||
return cast("Func", wrapped)
|
||||
|
||||
|
@ -148,4 +148,65 @@ async def test_inline_handlers_share_parent_context_multiple() -> None:
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
], f"Expected order of states was broken due to context loss. Got {states}"
|
||||
]
|
||||
|
||||
|
||||
async def test_shielded_callback_context_preservation() -> None:
|
||||
"""Verify that shielded callbacks preserve context variables.
|
||||
|
||||
This test specifically addresses the issue where async callbacks decorated
|
||||
with @shielded do not properly preserve context variables, breaking
|
||||
instrumentation and other context-dependent functionality.
|
||||
|
||||
The issue manifests in callbacks that use the @shielded decorator:
|
||||
* on_llm_end
|
||||
* on_llm_error
|
||||
* on_chain_end
|
||||
* on_chain_error
|
||||
* And other shielded callback methods
|
||||
"""
|
||||
context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")
|
||||
|
||||
class ContextTestHandler(AsyncCallbackHandler):
|
||||
"""Handler that reads context variables in shielded callbacks."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.run_inline = False
|
||||
self.context_values: list[str] = []
|
||||
|
||||
@override
|
||||
async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||
"""This method is decorated with @shielded in the run manager."""
|
||||
# This should preserve the context variable value
|
||||
self.context_values.append(context_var.get("not_found"))
|
||||
|
||||
@override
|
||||
async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
|
||||
"""This method is decorated with @shielded in the run manager."""
|
||||
# This should preserve the context variable value
|
||||
self.context_values.append(context_var.get("not_found"))
|
||||
|
||||
# Set up the test context
|
||||
context_var.set("test_value")
|
||||
handler = ContextTestHandler()
|
||||
manager = AsyncCallbackManager(handlers=[handler])
|
||||
|
||||
# Create run managers that have the shielded methods
|
||||
llm_managers = await manager.on_llm_start({}, ["test prompt"])
|
||||
llm_run_manager = llm_managers[0]
|
||||
|
||||
chain_run_manager = await manager.on_chain_start({}, {"test": "input"})
|
||||
|
||||
# Test LLM end callback (which is shielded)
|
||||
await llm_run_manager.on_llm_end({"response": "test"}) # type: ignore[arg-type]
|
||||
|
||||
# Test Chain end callback (which is shielded)
|
||||
await chain_run_manager.on_chain_end({"output": "test"})
|
||||
|
||||
# The context should be preserved in shielded callbacks
|
||||
# This was the main issue - shielded decorators were not preserving context
|
||||
assert handler.context_values == ["test_value", "test_value"], (
|
||||
f"Expected context values ['test_value', 'test_value'], "
|
||||
f"but got {handler.context_values}. "
|
||||
f"This indicates the shielded decorator is not preserving context variables."
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user