Unset context after step (#30378)

While we are already careful to copy before setting the config, if other
objects hold a reference to the config or context, it wouldn't be
cleared.
This commit is contained in:
William FH 2025-03-19 11:46:23 -07:00 committed by GitHub
parent 37190881d3
commit 4130e6476b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 227 additions and 211 deletions

View File

@ -17,7 +17,6 @@ from collections.abc import (
Sequence,
)
from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from functools import wraps
from itertools import groupby, tee
from operator import itemgetter
@ -47,7 +46,6 @@ from langchain_core.load.serializable import (
)
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
@ -58,6 +56,7 @@ from langchain_core.runnables.config import (
merge_configs,
patch_config,
run_in_executor,
set_config_context,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
@ -1920,8 +1919,7 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
output = cast(
Output,
context.run(
@ -1970,8 +1968,7 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs
)
@ -2182,8 +2179,7 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
@ -2195,7 +2191,9 @@ class Runnable(Generic[Input, Output], ABC):
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator)
iterator = stream_handler.tap_output_iter(
run_manager.run_id, iterator
)
try:
while True:
chunk: Output = context.run(next, iterator) # type: ignore
@ -2283,8 +2281,7 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config
if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
@ -3021,8 +3018,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
context = copy_context()
context.run(_set_config_context, config)
with set_config_context(config) as context:
if i == 0:
input = context.run(step.invoke, input, config, **kwargs)
else:
@ -3061,8 +3057,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
context = copy_context()
context.run(_set_config_context, config)
with set_config_context(config) as context:
if i == 0:
part = functools.partial(step.ainvoke, input, config, **kwargs)
else:
@ -3713,8 +3708,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
# mark each step as a child run
callbacks=run_manager.get_child(f"map:key:{key}"),
)
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
return context.run(
step.invoke,
input,
@ -3764,8 +3758,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
config,
callbacks=run_manager.get_child(f"map:key:{key}"),
)
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
if asyncio_accepts_context():
return await asyncio.create_task( # type: ignore
step.ainvoke(input, child_config), context=context

View File

@ -6,7 +6,7 @@ import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar, copy_context
from contextvars import Context, ContextVar, Token, copy_context
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
@ -115,7 +115,10 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar(
)
def _set_config_context(config: RunnableConfig) -> None:
# This is imported and used in langgraph, so don't break.
def _set_config_context(
config: RunnableConfig,
) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]:
"""Set the child Runnable config + tracing context.
Args:
@ -123,7 +126,8 @@ def _set_config_context(config: RunnableConfig) -> None:
"""
from langchain_core.tracers.langchain import LangChainTracer
var_child_runnable_config.set(config)
config_token = var_child_runnable_config.set(config)
current_context = None
if (
(callbacks := config.get("callbacks"))
and (
@ -141,9 +145,32 @@ def _set_config_context(config: RunnableConfig) -> None:
)
and (run := tracer.run_map.get(str(parent_run_id)))
):
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
current_context = get_tracing_context()
_set_tracing_context({"parent": run})
return config_token, current_context
@contextmanager
def set_config_context(
config: RunnableConfig, ctx: Optional[Context] = None
) -> Generator[Context, None, None]:
"""Set the child Runnable config + tracing context.
Args:
config (RunnableConfig): The config to set.
"""
from langsmith.run_helpers import _set_tracing_context
_set_tracing_context({"parent": run})
ctx = ctx if ctx is not None else copy_context()
config_token, current_context = ctx.run(_set_config_context, config)
try:
yield ctx
finally:
ctx.run(var_child_runnable_config.reset, config_token)
if current_context is not None:
ctx.run(_set_tracing_context, current_context)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:

View File

@ -2,7 +2,6 @@ import asyncio
import inspect
import typing
from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context
from functools import wraps
from typing import (
TYPE_CHECKING,
@ -18,12 +17,12 @@ from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
patch_config,
set_config_context,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
@ -174,8 +173,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
output = context.run(
runnable.invoke,
input,
@ -228,9 +226,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
coro = runnable.ainvoke(input, child_config, **kwargs)
with set_config_context(child_config) as context:
coro = context.run(runnable.ainvoke, input, config, **kwargs)
if asyncio_accepts_context():
output = await asyncio.create_task(coro, context=context) # type: ignore
else:
@ -475,8 +472,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
stream = context.run(
runnable.stream,
input,
@ -539,8 +535,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None:
input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
stream = runnable.astream(
input,
child_config,

View File

@ -6,7 +6,6 @@ import inspect
import json
import warnings
from abc import ABC, abstractmethod
from contextvars import copy_context
from inspect import signature
from typing import (
TYPE_CHECKING,
@ -52,7 +51,7 @@ from langchain_core.runnables import (
patch_config,
run_in_executor,
)
from langchain_core.runnables.config import _set_config_context
from langchain_core.runnables.config import set_config_context
from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import (
_parse_google_docstring,
@ -722,9 +721,10 @@ class ChildTool(BaseTool):
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
with set_config_context(child_config) as context:
tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_input, tool_call_id
)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs = tool_kwargs | {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run):
@ -832,8 +832,7 @@ class ChildTool(BaseTool):
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
with set_config_context(child_config) as context:
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
@ -842,7 +841,7 @@ class ChildTool(BaseTool):
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
coro = self._arun(*tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:

View File

@ -89,11 +89,11 @@ def tracing_v2_enabled(
tags=tags,
client=client,
)
token = tracing_v2_callback_var.set(cb)
try:
tracing_v2_callback_var.set(cb)
yield cb
finally:
tracing_v2_callback_var.set(None)
tracing_v2_callback_var.reset(token)
@contextmanager
@ -109,9 +109,11 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
run_id = runs_cb.traced_runs[0].id
"""
cb = RunCollectorCallbackHandler()
run_collector_var.set(cb)
token = run_collector_var.set(cb)
try:
yield cb
run_collector_var.set(None)
finally:
run_collector_var.reset(token)
def _get_trace_callbacks(