mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
format
This commit is contained in:
parent
f365865fb5
commit
cd14cc5014
@ -238,10 +238,10 @@ def shielded(func: Func) -> Func:
|
|||||||
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||||
# Capture the current context to preserve context variables
|
# Capture the current context to preserve context variables
|
||||||
ctx = copy_context()
|
ctx = copy_context()
|
||||||
|
|
||||||
# Create the coroutine
|
# Create the coroutine
|
||||||
coro = func(*args, **kwargs)
|
coro = func(*args, **kwargs)
|
||||||
|
|
||||||
# For Python 3.11+, create task with explicit context
|
# For Python 3.11+, create task with explicit context
|
||||||
# For older versions, fallback to original behavior
|
# For older versions, fallback to original behavior
|
||||||
try:
|
try:
|
||||||
|
@ -153,56 +153,56 @@ async def test_inline_handlers_share_parent_context_multiple() -> None:
|
|||||||
|
|
||||||
async def test_shielded_callback_context_preservation() -> None:
|
async def test_shielded_callback_context_preservation() -> None:
|
||||||
"""Verify that shielded callbacks preserve context variables.
|
"""Verify that shielded callbacks preserve context variables.
|
||||||
|
|
||||||
This test specifically addresses the issue where async callbacks decorated
|
This test specifically addresses the issue where async callbacks decorated
|
||||||
with @shielded do not properly preserve context variables, breaking
|
with @shielded do not properly preserve context variables, breaking
|
||||||
instrumentation and other context-dependent functionality.
|
instrumentation and other context-dependent functionality.
|
||||||
|
|
||||||
The issue manifests in callbacks that use the @shielded decorator:
|
The issue manifests in callbacks that use the @shielded decorator:
|
||||||
* on_llm_end
|
* on_llm_end
|
||||||
* on_llm_error
|
* on_llm_error
|
||||||
* on_chain_end
|
* on_chain_end
|
||||||
* on_chain_error
|
* on_chain_error
|
||||||
* And other shielded callback methods
|
* And other shielded callback methods
|
||||||
"""
|
"""
|
||||||
context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")
|
context_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_context")
|
||||||
|
|
||||||
class ContextTestHandler(AsyncCallbackHandler):
|
class ContextTestHandler(AsyncCallbackHandler):
|
||||||
"""Handler that reads context variables in shielded callbacks."""
|
"""Handler that reads context variables in shielded callbacks."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.run_inline = False
|
self.run_inline = False
|
||||||
self.context_values = []
|
self.context_values = []
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||||
"""This method is decorated with @shielded in the run manager."""
|
"""This method is decorated with @shielded in the run manager."""
|
||||||
# This should preserve the context variable value
|
# This should preserve the context variable value
|
||||||
self.context_values.append(context_var.get("not_found"))
|
self.context_values.append(context_var.get("not_found"))
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
|
async def on_chain_end(self, outputs: Any, **kwargs: Any) -> None:
|
||||||
"""This method is decorated with @shielded in the run manager."""
|
"""This method is decorated with @shielded in the run manager."""
|
||||||
# This should preserve the context variable value
|
# This should preserve the context variable value
|
||||||
self.context_values.append(context_var.get("not_found"))
|
self.context_values.append(context_var.get("not_found"))
|
||||||
|
|
||||||
# Set up the test context
|
# Set up the test context
|
||||||
context_var.set("test_value")
|
context_var.set("test_value")
|
||||||
handler = ContextTestHandler()
|
handler = ContextTestHandler()
|
||||||
manager = AsyncCallbackManager(handlers=[handler])
|
manager = AsyncCallbackManager(handlers=[handler])
|
||||||
|
|
||||||
# Create run managers that have the shielded methods
|
# Create run managers that have the shielded methods
|
||||||
llm_managers = await manager.on_llm_start({}, ["test prompt"])
|
llm_managers = await manager.on_llm_start({}, ["test prompt"])
|
||||||
llm_run_manager = llm_managers[0]
|
llm_run_manager = llm_managers[0]
|
||||||
|
|
||||||
chain_run_manager = await manager.on_chain_start({}, {"test": "input"})
|
chain_run_manager = await manager.on_chain_start({}, {"test": "input"})
|
||||||
|
|
||||||
# Test LLM end callback (which is shielded)
|
# Test LLM end callback (which is shielded)
|
||||||
await llm_run_manager.on_llm_end({"response": "test"})
|
await llm_run_manager.on_llm_end({"response": "test"})
|
||||||
|
|
||||||
# Test Chain end callback (which is shielded)
|
# Test Chain end callback (which is shielded)
|
||||||
await chain_run_manager.on_chain_end({"output": "test"})
|
await chain_run_manager.on_chain_end({"output": "test"})
|
||||||
|
|
||||||
# The context should be preserved in shielded callbacks
|
# The context should be preserved in shielded callbacks
|
||||||
# This was the main issue - shielded decorators were not preserving context
|
# This was the main issue - shielded decorators were not preserving context
|
||||||
assert handler.context_values == ["test_value", "test_value"], (
|
assert handler.context_values == ["test_value", "test_value"], (
|
||||||
|
Loading…
Reference in New Issue
Block a user