Unset context after step (#30378)

While we are already careful to copy before setting the config, if other
objects hold a reference to the config or context, it wouldn't be
cleared.
This commit is contained in:
William FH 2025-03-19 11:46:23 -07:00 committed by GitHub
parent 37190881d3
commit 4130e6476b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 227 additions and 211 deletions

View File

@ -17,7 +17,6 @@ from collections.abc import (
Sequence, Sequence,
) )
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from functools import wraps from functools import wraps
from itertools import groupby, tee from itertools import groupby, tee
from operator import itemgetter from operator import itemgetter
@ -47,7 +46,6 @@ from langchain_core.load.serializable import (
) )
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
_set_config_context,
acall_func_with_variable_args, acall_func_with_variable_args,
call_func_with_variable_args, call_func_with_variable_args,
ensure_config, ensure_config,
@ -58,6 +56,7 @@ from langchain_core.runnables.config import (
merge_configs, merge_configs,
patch_config, patch_config,
run_in_executor, run_in_executor,
set_config_context,
) )
from langchain_core.runnables.graph import Graph from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
@ -1920,19 +1919,18 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) output = cast(
output = cast( Output,
Output, context.run(
context.run( call_func_with_variable_args, # type: ignore[arg-type]
call_func_with_variable_args, # type: ignore[arg-type] func, # type: ignore[arg-type]
func, # type: ignore[arg-type] input, # type: ignore[arg-type]
input, # type: ignore[arg-type] config,
config, run_manager,
run_manager, **kwargs,
**kwargs, ),
), )
)
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
@ -1970,15 +1968,14 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) coro = acall_func_with_variable_args(
coro = acall_func_with_variable_args( func, input, config, run_manager, **kwargs
func, input, config, run_manager, **kwargs )
) if asyncio_accepts_context():
if asyncio_accepts_context(): output: Output = await asyncio.create_task(coro, context=context) # type: ignore
output: Output = await asyncio.create_task(coro, context=context) # type: ignore else:
else: output = await coro
output = await coro
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
@ -2182,49 +2179,50 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config kwargs["config"] = child_config
if accepts_run_manager(transformer): if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager kwargs["run_manager"] = run_manager
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] if stream_handler := next(
if stream_handler := next( (
( cast(_StreamingCallbackHandler, h)
cast(_StreamingCallbackHandler, h) for h in run_manager.handlers
for h in run_manager.handlers # instance check OK here, it's a mixin
# instance check OK here, it's a mixin if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] ),
), None,
None, ):
): # populates streamed_output in astream_log() output if needed
# populates streamed_output in astream_log() output if needed iterator = stream_handler.tap_output_iter(
iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator) run_manager.run_id, iterator
try: )
while True: try:
chunk: Output = context.run(next, iterator) # type: ignore while True:
yield chunk chunk: Output = context.run(next, iterator) # type: ignore
if final_output_supported: yield chunk
if final_output is None: if final_output_supported:
if final_output is None:
final_output = chunk
else:
try:
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = chunk
final_output_supported = False
else:
final_output = chunk final_output = chunk
except (StopIteration, GeneratorExit):
pass
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else: else:
try: try:
final_output = final_output + chunk # type: ignore final_input = final_input + ichunk # type: ignore
except TypeError: except TypeError:
final_output = chunk final_input = ichunk
final_output_supported = False final_input_supported = False
else: else:
final_output = chunk
except (StopIteration, GeneratorExit):
pass
for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input) run_manager.on_chain_error(e, inputs=final_input)
raise raise
@ -2283,60 +2281,59 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["config"] = child_config kwargs["config"] = child_config
if accepts_run_manager(transformer): if accepts_run_manager(transformer):
kwargs["run_manager"] = run_manager kwargs["run_manager"] = run_manager
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next( if stream_handler := next(
( (
cast(_StreamingCallbackHandler, h) cast(_StreamingCallbackHandler, h)
for h in run_manager.handlers for h in run_manager.handlers
# instance check OK here, it's a mixin # instance check OK here, it's a mixin
if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc]
), ),
None, None,
): ):
# populates streamed_output in astream_log() output if needed # populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_aiter( iterator = stream_handler.tap_output_aiter(
run_manager.run_id, iterator_ run_manager.run_id, iterator_
) )
else: else:
iterator = iterator_ iterator = iterator_
try: try:
while True: while True:
if asyncio_accepts_context(): if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg] chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(iterator), # type: ignore[arg-type] py_anext(iterator), # type: ignore[arg-type]
context=context, context=context,
) )
else: else:
chunk = cast(Output, await py_anext(iterator)) chunk = cast(Output, await py_anext(iterator))
yield chunk yield chunk
if final_output_supported: if final_output_supported:
if final_output is None: if final_output is None:
final_output = chunk
else:
try:
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = chunk
final_output_supported = False
else:
final_output = chunk final_output = chunk
except StopAsyncIteration:
pass
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk
else: else:
try: try:
final_output = final_output + chunk # type: ignore final_input = final_input + ichunk # type: ignore[operator]
except TypeError: except TypeError:
final_output = chunk final_input = ichunk
final_output_supported = False final_input_supported = False
else: else:
final_output = chunk
except StopAsyncIteration:
pass
async for ichunk in input_for_tracing:
if final_input_supported:
if final_input is None:
final_input = ichunk final_input = ichunk
else:
try:
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = ichunk
final_input_supported = False
else:
final_input = ichunk
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e, inputs=final_input) await run_manager.on_chain_error(e, inputs=final_input)
raise raise
@ -3021,12 +3018,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
) )
context = copy_context() with set_config_context(config) as context:
context.run(_set_config_context, config) if i == 0:
if i == 0: input = context.run(step.invoke, input, config, **kwargs)
input = context.run(step.invoke, input, config, **kwargs) else:
else: input = context.run(step.invoke, input, config)
input = context.run(step.invoke, input, config)
# finish the root run # finish the root run
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@ -3061,17 +3057,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
config = patch_config( config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
) )
context = copy_context() with set_config_context(config) as context:
context.run(_set_config_context, config) if i == 0:
if i == 0: part = functools.partial(step.ainvoke, input, config, **kwargs)
part = functools.partial(step.ainvoke, input, config, **kwargs) else:
else: part = functools.partial(step.ainvoke, input, config)
part = functools.partial(step.ainvoke, input, config) if asyncio_accepts_context():
if asyncio_accepts_context(): input = await asyncio.create_task(part(), context=context) # type: ignore
input = await asyncio.create_task(part(), context=context) # type: ignore else:
else: input = await asyncio.create_task(part())
input = await asyncio.create_task(part()) # finish the root run
# finish the root run
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
@ -3713,13 +3708,12 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
# mark each step as a child run # mark each step as a child run
callbacks=run_manager.get_child(f"map:key:{key}"), callbacks=run_manager.get_child(f"map:key:{key}"),
) )
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) return context.run(
return context.run( step.invoke,
step.invoke, input,
input, child_config,
child_config, )
)
# gather results from all steps # gather results from all steps
try: try:
@ -3764,14 +3758,13 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
config, config,
callbacks=run_manager.get_child(f"map:key:{key}"), callbacks=run_manager.get_child(f"map:key:{key}"),
) )
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) if asyncio_accepts_context():
if asyncio_accepts_context(): return await asyncio.create_task( # type: ignore
return await asyncio.create_task( # type: ignore step.ainvoke(input, child_config), context=context
step.ainvoke(input, child_config), context=context )
) else:
else: return await asyncio.create_task(step.ainvoke(input, child_config))
return await asyncio.create_task(step.ainvoke(input, child_config))
# gather results from all steps # gather results from all steps
try: try:

View File

@ -6,7 +6,7 @@ import warnings
from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from contextvars import ContextVar, copy_context from contextvars import Context, ContextVar, Token, copy_context
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
@ -115,7 +115,10 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar(
) )
def _set_config_context(config: RunnableConfig) -> None: # This is imported and used in langgraph, so don't break.
def _set_config_context(
config: RunnableConfig,
) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]:
"""Set the child Runnable config + tracing context. """Set the child Runnable config + tracing context.
Args: Args:
@ -123,7 +126,8 @@ def _set_config_context(config: RunnableConfig) -> None:
""" """
from langchain_core.tracers.langchain import LangChainTracer from langchain_core.tracers.langchain import LangChainTracer
var_child_runnable_config.set(config) config_token = var_child_runnable_config.set(config)
current_context = None
if ( if (
(callbacks := config.get("callbacks")) (callbacks := config.get("callbacks"))
and ( and (
@ -141,9 +145,32 @@ def _set_config_context(config: RunnableConfig) -> None:
) )
and (run := tracer.run_map.get(str(parent_run_id))) and (run := tracer.run_map.get(str(parent_run_id)))
): ):
from langsmith.run_helpers import _set_tracing_context from langsmith.run_helpers import _set_tracing_context, get_tracing_context
current_context = get_tracing_context()
_set_tracing_context({"parent": run}) _set_tracing_context({"parent": run})
return config_token, current_context
@contextmanager
def set_config_context(
config: RunnableConfig, ctx: Optional[Context] = None
) -> Generator[Context, None, None]:
"""Set the child Runnable config + tracing context.
Args:
config (RunnableConfig): The config to set.
"""
from langsmith.run_helpers import _set_tracing_context
ctx = ctx if ctx is not None else copy_context()
config_token, current_context = ctx.run(_set_config_context, config)
try:
yield ctx
finally:
ctx.run(var_child_runnable_config.reset, config_token)
if current_context is not None:
ctx.run(_set_tracing_context, current_context)
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:

View File

@ -2,7 +2,6 @@ import asyncio
import inspect import inspect
import typing import typing
from collections.abc import AsyncIterator, Iterator, Sequence from collections.abc import AsyncIterator, Iterator, Sequence
from contextvars import copy_context
from functools import wraps from functools import wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -18,12 +17,12 @@ from typing_extensions import override
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import ( from langchain_core.runnables.config import (
RunnableConfig, RunnableConfig,
_set_config_context,
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
get_config_list, get_config_list,
patch_config, patch_config,
set_config_context,
) )
from langchain_core.runnables.utils import ( from langchain_core.runnables.utils import (
ConfigurableFieldSpec, ConfigurableFieldSpec,
@ -174,14 +173,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) output = context.run(
output = context.run( runnable.invoke,
runnable.invoke, input,
input, config,
config, **kwargs,
**kwargs, )
)
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
first_error = e first_error = e
@ -228,13 +226,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) coro = context.run(runnable.ainvoke, input, config, **kwargs)
coro = runnable.ainvoke(input, child_config, **kwargs) if asyncio_accepts_context():
if asyncio_accepts_context(): output = await asyncio.create_task(coro, context=context) # type: ignore
output = await asyncio.create_task(coro, context=context) # type: ignore else:
else: output = await coro
output = await coro
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
if first_error is None: if first_error is None:
first_error = e first_error = e
@ -475,14 +472,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) stream = context.run(
stream = context.run( runnable.stream,
runnable.stream, input,
input, **kwargs,
**kwargs, )
) chunk: Output = context.run(next, stream) # type: ignore
chunk: Output = context.run(next, stream) # type: ignore
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error first_error = e if first_error is None else first_error
last_error = e last_error = e
@ -539,20 +535,19 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
if self.exception_key and last_error is not None: if self.exception_key and last_error is not None:
input[self.exception_key] = last_error input[self.exception_key] = last_error
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) stream = runnable.astream(
stream = runnable.astream( input,
input, child_config,
child_config, **kwargs,
**kwargs,
)
if asyncio_accepts_context():
chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(stream), # type: ignore[arg-type]
context=context,
) )
else: if asyncio_accepts_context():
chunk = cast(Output, await py_anext(stream)) chunk: Output = await asyncio.create_task( # type: ignore[call-arg]
py_anext(stream), # type: ignore[arg-type]
context=context,
)
else:
chunk = cast(Output, await py_anext(stream))
except self.exceptions_to_handle as e: except self.exceptions_to_handle as e:
first_error = e if first_error is None else first_error first_error = e if first_error is None else first_error
last_error = e last_error = e

View File

@ -6,7 +6,6 @@ import inspect
import json import json
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextvars import copy_context
from inspect import signature from inspect import signature
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -52,7 +51,7 @@ from langchain_core.runnables import (
patch_config, patch_config,
run_in_executor, run_in_executor,
) )
from langchain_core.runnables.config import _set_config_context from langchain_core.runnables.config import set_config_context
from langchain_core.runnables.utils import asyncio_accepts_context from langchain_core.runnables.utils import asyncio_accepts_context
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
_parse_google_docstring, _parse_google_docstring,
@ -722,14 +721,15 @@ class ChildTool(BaseTool):
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) tool_args, tool_kwargs = self._to_args_and_kwargs(
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) tool_input, tool_call_id
if signature(self._run).parameters.get("run_manager"): )
tool_kwargs = tool_kwargs | {"run_manager": run_manager} if signature(self._run).parameters.get("run_manager"):
if config_param := _get_runnable_config_param(self._run): tool_kwargs = tool_kwargs | {"run_manager": run_manager}
tool_kwargs = tool_kwargs | {config_param: config} if config_param := _get_runnable_config_param(self._run):
response = context.run(self._run, *tool_args, **tool_kwargs) tool_kwargs = tool_kwargs | {config_param: config}
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact": if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2: if not isinstance(response, tuple) or len(response) != 2:
msg = ( msg = (
@ -832,21 +832,20 @@ class ChildTool(BaseTool):
try: try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id)
child_config = patch_config(config, callbacks=run_manager.get_child()) child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context() with set_config_context(child_config) as context:
context.run(_set_config_context, child_config) func_to_check = (
func_to_check = ( self._run if self.__class__._arun is BaseTool._arun else self._arun
self._run if self.__class__._arun is BaseTool._arun else self._arun )
) if signature(func_to_check).parameters.get("run_manager"):
if signature(func_to_check).parameters.get("run_manager"): tool_kwargs["run_manager"] = run_manager
tool_kwargs["run_manager"] = run_manager if config_param := _get_runnable_config_param(func_to_check):
if config_param := _get_runnable_config_param(func_to_check): tool_kwargs[config_param] = config
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs) coro = self._arun(*tool_args, **tool_kwargs)
if asyncio_accepts_context(): if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore response = await asyncio.create_task(coro, context=context) # type: ignore
else: else:
response = await coro response = await coro
if self.response_format == "content_and_artifact": if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2: if not isinstance(response, tuple) or len(response) != 2:
msg = ( msg = (

View File

@ -89,11 +89,11 @@ def tracing_v2_enabled(
tags=tags, tags=tags,
client=client, client=client,
) )
token = tracing_v2_callback_var.set(cb)
try: try:
tracing_v2_callback_var.set(cb)
yield cb yield cb
finally: finally:
tracing_v2_callback_var.set(None) tracing_v2_callback_var.reset(token)
@contextmanager @contextmanager
@ -109,9 +109,11 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
run_id = runs_cb.traced_runs[0].id run_id = runs_cb.traced_runs[0].id
""" """
cb = RunCollectorCallbackHandler() cb = RunCollectorCallbackHandler()
run_collector_var.set(cb) token = run_collector_var.set(cb)
yield cb try:
run_collector_var.set(None) yield cb
finally:
run_collector_var.reset(token)
def _get_trace_callbacks( def _get_trace_callbacks(