Unset context after step (#30378)

While we are already careful to copy before setting the config, if other
objects hold a reference to the config or context, it wouldn't be
cleared.
This commit is contained in:
William FH 2025-03-19 11:46:23 -07:00 committed by GitHub
parent 37190881d3
commit 4130e6476b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 227 additions and 211 deletions

View File

@ -17,7 +17,6 @@ from collections.abc import (
Sequence, Sequence,
) )
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from functools import wraps from functools import wraps
from itertools import groupby, tee from itertools import groupby, tee
from operator import itemgetter from operator import itemgetter
@ -47,7 +46,6 @@ from langchain_core.load.serializable import (
) )
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
_set_config_context,
acall_func_with_variable_args, acall_func_with_variable_args,
call_func_with_variable_args, call_func_with_variable_args,
ensure_config, ensure_config,
@ -58,6 +56,7 @@ from langchain_core.runnables.config import (
merge_configs, merge_configs,
patch_config, patch_config,
run_in_executor, run_in_executor,
set_config_context,
) )
from langchain_core.runnables.graph import Graph from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
@ -1920,8 +1919,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
output = cast( output = cast(
Output, Output,
context.run( context.run(
@ -1970,8 +1968,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
coro = acall_func_with_variable_args( coro = acall_func_with_variable_args(
func, input, config, run_manager, **kwargs func, input, config, run_manager, **kwargs
) )
@ -2182,8 +2179,7 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config kwargs["config"] = child_config
if accepts_run_manager(transformer): if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager kwargs["run_manager"] = run_manager
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next( if stream_handler := next(
( (
@ -2195,7 +2191,9 @@ class Runnable(Generic[Input, Output], ABC):
None, None,
): ):
# populates streamed_output in astream_log() output if needed # populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator) iterator = stream_handler.tap_output_iter(
run_manager.run_id, iterator
)
try: try:
while True: while True:
chunk: Output = context.run(next, iterator) # type: ignore chunk: Output = context.run(next, iterator) # type: ignore
@ -2283,8 +2281,7 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config kwargs["config"] = child_config
if accepts_run_manager(transformer): if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager kwargs["run_manager"] = run_manager
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next( if stream_handler := next(
@ -3021,8 +3018,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
) )
context = copy_context() with set_config_context(config) as context:
context.run(_set_config_context, config)
if i == 0: if i == 0:
input = context.run(step.invoke, input, config, **kwargs) input = context.run(step.invoke, input, config, **kwargs)
else: else:
@ -3061,8 +3057,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
) )
context = copy_context() with set_config_context(config) as context:
context.run(_set_config_context, config)
if i == 0: if i == 0:
part = functools.partial(step.ainvoke, input, config, **kwargs) part = functools.partial(step.ainvoke, input, config, **kwargs)
else: else:
@ -3713,8 +3708,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
# mark each step as a child run # mark each step as a child run
callbacks=run_manager.get_child(f"map:key:{key}"), callbacks=run_manager.get_child(f"map:key:{key}"),
) )
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
return context.run( return context.run(
step.invoke, step.invoke,
input, input,
@ -3764,8 +3758,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
config, config,
callbacks=run_manager.get_child(f"map:key:{key}"), callbacks=run_manager.get_child(f"map:key:{key}"),
) )
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
if asyncio_accepts_context(): if asyncio_accepts_context():
return await asyncio.create_task( # type: ignore return await asyncio.create_task( # type: ignore
step.ainvoke(input, child_config), context=context step.ainvoke(input, child_config), context=context

View File

@ -6,7 +6,7 @@ import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar, copy_context from contextvars import Context, ContextVar, Token, copy_context
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
@ -115,7 +115,10 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar(
) )
def _set_config_context(config: RunnableConfig) -> None: # This is imported and used in langgraph, so don't break.
def _set_config_context(
config: RunnableConfig,
) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]:
"""Set the child Runnable config + tracing context. """Set the child Runnable config + tracing context.
Args: Args:
@ -123,7 +126,8 @@ def _set_config_context(config: RunnableConfig) -> None:
""" """
from langchain_core.tracers.langchain import LangChainTracer from langchain_core.tracers.langchain import LangChainTracer
var_child_runnable_config.set(config) config_token = var_child_runnable_config.set(config)
current_context = None
if ( if (
(callbacks := config.get("callbacks")) (callbacks := config.get("callbacks"))
and ( and (
@ -141,9 +145,32 @@ def _set_config_context(config: RunnableConfig) -> None:
) )
and (run := tracer.run_map.get(str(parent_run_id))) and (run := tracer.run_map.get(str(parent_run_id)))
): ):
from langsmith.run_helpers import _set_tracing_context, get_tracing_context
current_context = get_tracing_context()
_set_tracing_context({"parent": run})
return config_token, current_context
@contextmanager
def set_config_context(
config: RunnableConfig, ctx: Optional[Context] = None
) -> Generator[Context, None, None]:
"""Set the child Runnable config + tracing context.
Args:
config (RunnableConfig): The config to set.
"""
from langsmith.run_helpers import _set_tracing_context from langsmith.run_helpers import _set_tracing_context
_set_tracing_context({"parent": run}) ctx = ctx if ctx is not None else copy_context()
config_token, current_context = ctx.run(_set_config_context, config)
try:
yield ctx
finally:
ctx.run(var_child_runnable_config.reset, config_token)
if current_context is not None:
ctx.run(_set_tracing_context, current_context)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:

View File

@ -2,7 +2,6 @@ import asyncio
import inspect import inspect
import typing import typing
from collections.abc import AsyncIterator, Iterator, Sequence from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context
from functools import wraps from functools import wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -18,12 +17,12 @@ from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
_set_config_context,
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
get_config_list, get_config_list,
patch_config, patch_config,
set_config_context,
) )
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
ConfigurableFieldSpec, ConfigurableFieldSpec,
@ -174,8 +173,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
output = context.run( output = context.run(
runnable.invoke, runnable.invoke,
input, input,
@ -228,9 +226,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) coro = context.run(runnable.ainvoke, input, config, **kwargs)
coro = runnable.ainvoke(input, child_config, **kwargs)
if asyncio_accepts_context(): if asyncio_accepts_context():
output = await asyncio.create_task(coro, context=context) # type: ignore output = await asyncio.create_task(coro, context=context) # type: ignore
else: else:
@ -475,8 +472,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
stream = context.run( stream = context.run(
runnable.stream, runnable.stream,
input, input,
@ -539,8 +535,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
stream = runnable.astream( stream = runnable.astream(
input, input,
child_config, child_config,

View File

@ -6,7 +6,6 @@ import inspect
import json import json
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextvars import copy_context
from inspect import signature from inspect import signature
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -52,7 +51,7 @@ from langchain_core.runnables import (
patch_config, patch_config,
run_in_executor, run_in_executor,
) )
from langchain_core.runnables.config import _set_config_context from langchain_core.runnables.config import set_config_context
from langchain_core.runnables.utils import asyncio_accepts_context from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
_parse_google_docstring, _parse_google_docstring,
@ -722,9 +721,10 @@ class ChildTool(BaseTool):
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) tool_input, tool_call_id
)
if signature(self._run).parameters.get("run_manager"): if signature(self._run).parameters.get("run_manager"):
tool_kwargs = tool_kwargs | {"run_manager": run_manager} tool_kwargs = tool_kwargs | {"run_manager": run_manager}
if config_param := _get_runnable_config_param(self._run): if config_param := _get_runnable_config_param(self._run):
@ -832,8 +832,7 @@ class ChildTool(BaseTool):
try: try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config)
func_to_check = ( func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun self._run if self.__class__._arun is BaseTool._arun else self._arun
) )
@ -842,7 +841,7 @@ class ChildTool(BaseTool):
if config_param := _get_runnable_config_param(func_to_check): if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs) coro = self._arun(*tool_args, **tool_kwargs)
if asyncio_accepts_context(): if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore response = await asyncio.create_task(coro, context=context) # type: ignore
else: else:

View File

@ -89,11 +89,11 @@ def tracing_v2_enabled(
tags=tags, tags=tags,
client=client, client=client,
) )
token = tracing_v2_callback_var.set(cb)
try: try:
tracing_v2_callback_var.set(cb)
yield cb yield cb
finally: finally:
tracing_v2_callback_var.set(None) tracing_v2_callback_var.reset(token)
@contextmanager @contextmanager
@ -109,9 +109,11 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
run_id = runs_cb.traced_runs[0].id run_id = runs_cb.traced_runs[0].id
""" """
cb = RunCollectorCallbackHandler() cb = RunCollectorCallbackHandler()
run_collector_var.set(cb) token = run_collector_var.set(cb)
try:
yield cb yield cb
run_collector_var.set(None) finally:
run_collector_var.reset(token)
def _get_trace_callbacks( def _get_trace_callbacks(