diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 40ea26d3193..4fbb1f96201 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -17,7 +17,6 @@ from collections.abc import ( Sequence, ) from concurrent.futures import FIRST_COMPLETED, wait -from contextvars import copy_context from functools import wraps from itertools import groupby, tee from operator import itemgetter @@ -47,7 +46,6 @@ from langchain_core.load.serializable import ( ) from langchain_core.runnables.config import ( RunnableConfig, - _set_config_context, acall_func_with_variable_args, call_func_with_variable_args, ensure_config, @@ -58,6 +56,7 @@ from langchain_core.runnables.config import ( merge_configs, patch_config, run_in_executor, + set_config_context, ) from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( @@ -1920,19 +1919,18 @@ class Runnable(Generic[Input, Output], ABC): ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - output = cast( - Output, - context.run( - call_func_with_variable_args, # type: ignore[arg-type] - func, # type: ignore[arg-type] - input, # type: ignore[arg-type] - config, - run_manager, - **kwargs, - ), - ) + with set_config_context(child_config) as context: + output = cast( + Output, + context.run( + call_func_with_variable_args, # type: ignore[arg-type] + func, # type: ignore[arg-type] + input, # type: ignore[arg-type] + config, + run_manager, + **kwargs, + ), + ) except BaseException as e: run_manager.on_chain_error(e) raise @@ -1970,15 +1968,14 @@ class Runnable(Generic[Input, Output], ABC): ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - coro = acall_func_with_variable_args( - func, input, config, run_manager, **kwargs - ) - if asyncio_accepts_context(): - output: Output = await asyncio.create_task(coro, context=context) # type: ignore - else: - output = await coro + with set_config_context(child_config) as context: + coro = acall_func_with_variable_args( + func, input, config, run_manager, **kwargs + ) + if asyncio_accepts_context(): + output: Output = await asyncio.create_task(coro, context=context) # type: ignore + else: + output = await coro except BaseException as e: await run_manager.on_chain_error(e) raise @@ -2182,49 +2179,50 @@ class Runnable(Generic[Input, Output], ABC): kwargs["config"] = child_config if accepts_run_manager(transformer): kwargs["run_manager"] = run_manager - context = copy_context() - context.run(_set_config_context, child_config) - iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] - if stream_handler := next( - ( - cast(_StreamingCallbackHandler, h) - for h in run_manager.handlers - # instance check OK here, it's a mixin - if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] - ), - None, - ): - # populates streamed_output in astream_log() output if needed - iterator = stream_handler.tap_output_iter(run_manager.run_id, iterator) - try: - while True: - chunk: Output = context.run(next, iterator) # type: ignore - yield chunk - if final_output_supported: - if final_output is None: + with set_config_context(child_config) as context: + iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] + if stream_handler := next( + ( + cast(_StreamingCallbackHandler, h) + for h in run_manager.handlers + # instance check OK here, it's a mixin + if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] + ), + None, + ): + # populates streamed_output in astream_log() output if needed + iterator = stream_handler.tap_output_iter( + run_manager.run_id, iterator + ) + try: + while True: + chunk: Output = context.run(next, iterator) # type: ignore + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = chunk + final_output_supported = False + else: final_output = chunk + except (StopIteration, GeneratorExit): + pass + for ichunk in input_for_tracing: + if final_input_supported: + if final_input is None: + final_input = ichunk else: try: - final_output = final_output + chunk # type: ignore + final_input = final_input + ichunk # type: ignore except TypeError: - final_output = chunk - final_output_supported = False + final_input = ichunk + final_input_supported = False else: - final_output = chunk - except (StopIteration, GeneratorExit): - pass - for ichunk in input_for_tracing: - if final_input_supported: - if final_input is None: final_input = ichunk - else: - try: - final_input = final_input + ichunk # type: ignore - except TypeError: - final_input = ichunk - final_input_supported = False - else: - final_input = ichunk except BaseException as e: run_manager.on_chain_error(e, inputs=final_input) raise @@ -2283,60 +2281,59 @@ class Runnable(Generic[Input, Output], ABC): kwargs["config"] = child_config if accepts_run_manager(transformer): kwargs["run_manager"] = run_manager - context = copy_context() - context.run(_set_config_context, child_config) - iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] + with set_config_context(child_config) as context: + iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] - if stream_handler := next( - ( - cast(_StreamingCallbackHandler, h) - for h in run_manager.handlers - # instance check OK here, it's a mixin - if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] - ), - None, - ): - # populates streamed_output in astream_log() output if needed - iterator = stream_handler.tap_output_aiter( - run_manager.run_id, iterator_ - ) - else: - iterator = iterator_ - try: - while True: - if asyncio_accepts_context(): - chunk: Output = await asyncio.create_task( # type: ignore[call-arg] - py_anext(iterator), # type: ignore[arg-type] - context=context, - ) - else: - chunk = cast(Output, await py_anext(iterator)) - yield chunk - if final_output_supported: - if final_output is None: + if stream_handler := next( + ( + cast(_StreamingCallbackHandler, h) + for h in run_manager.handlers + # instance check OK here, it's a mixin + if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] + ), + None, + ): + # populates streamed_output in astream_log() output if needed + iterator = stream_handler.tap_output_aiter( + run_manager.run_id, iterator_ + ) + else: + iterator = iterator_ + try: + while True: + if asyncio_accepts_context(): + chunk: Output = await asyncio.create_task( # type: ignore[call-arg] + py_anext(iterator), # type: ignore[arg-type] + context=context, + ) + else: + chunk = cast(Output, await py_anext(iterator)) + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = chunk + final_output_supported = False + else: final_output = chunk + except StopAsyncIteration: + pass + async for ichunk in input_for_tracing: + if final_input_supported: + if final_input is None: + final_input = ichunk else: try: - final_output = final_output + chunk # type: ignore + final_input = final_input + ichunk # type: ignore[operator] except TypeError: - final_output = chunk - final_output_supported = False + final_input = ichunk + final_input_supported = False else: - final_output = chunk - except StopAsyncIteration: - pass - async for ichunk in input_for_tracing: - if final_input_supported: - if final_input is None: final_input = ichunk - else: - try: - final_input = final_input + ichunk # type: ignore[operator] - except TypeError: - final_input = ichunk - final_input_supported = False - else: - final_input = ichunk except BaseException as e: await run_manager.on_chain_error(e, inputs=final_input) raise @@ -3021,12 +3018,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") ) - context = copy_context() - context.run(_set_config_context, config) - if i == 0: - input = context.run(step.invoke, input, config, **kwargs) - else: - input = context.run(step.invoke, input, config) + with set_config_context(config) as context: + if i == 0: + input = context.run(step.invoke, input, config, **kwargs) + else: + input = context.run(step.invoke, input, config) # finish the root run except BaseException as e: run_manager.on_chain_error(e) @@ -3061,17 +3057,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") ) - context = copy_context() - context.run(_set_config_context, config) - if i == 0: - part = functools.partial(step.ainvoke, input, config, **kwargs) - else: - part = functools.partial(step.ainvoke, input, config) - if asyncio_accepts_context(): - input = await asyncio.create_task(part(), context=context) # type: ignore - else: - input = await asyncio.create_task(part()) - # finish the root run + with set_config_context(config) as context: + if i == 0: + part = functools.partial(step.ainvoke, input, config, **kwargs) + else: + part = functools.partial(step.ainvoke, input, config) + if asyncio_accepts_context(): + input = await asyncio.create_task(part(), context=context) # type: ignore + else: + input = await asyncio.create_task(part()) + # finish the root run except BaseException as e: await run_manager.on_chain_error(e) raise @@ -3713,13 +3708,12 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): # mark each step as a child run callbacks=run_manager.get_child(f"map:key:{key}"), ) - context = copy_context() - context.run(_set_config_context, child_config) - return context.run( - step.invoke, - input, - child_config, - ) + with set_config_context(child_config) as context: + return context.run( + step.invoke, + input, + child_config, + ) # gather results from all steps try: @@ -3764,14 +3758,13 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): config, callbacks=run_manager.get_child(f"map:key:{key}"), ) - context = copy_context() - context.run(_set_config_context, child_config) - if asyncio_accepts_context(): - return await asyncio.create_task( # type: ignore - step.ainvoke(input, child_config), context=context - ) - else: - return await asyncio.create_task(step.ainvoke(input, child_config)) + with set_config_context(child_config) as context: + if asyncio_accepts_context(): + return await asyncio.create_task( # type: ignore + step.ainvoke(input, child_config), context=context + ) + else: + return await asyncio.create_task(step.ainvoke(input, child_config)) # gather results from all steps try: diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 126deb86701..63328a05e3a 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Awaitable, Generator, Iterable, Iterator, Sequence from concurrent.futures import Executor, Future, ThreadPoolExecutor from contextlib import contextmanager -from contextvars import ContextVar, copy_context +from contextvars import Context, ContextVar, Token, copy_context from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast @@ -115,7 +115,10 @@ var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar( ) -def _set_config_context(config: RunnableConfig) -> None: +# This is imported and used in langgraph, so don't break. +def _set_config_context( + config: RunnableConfig, +) -> tuple[Token[Optional[RunnableConfig]], Optional[dict[str, Any]]]: """Set the child Runnable config + tracing context. Args: @@ -123,7 +126,8 @@ def _set_config_context(config: RunnableConfig) -> None: """ from langchain_core.tracers.langchain import LangChainTracer - var_child_runnable_config.set(config) + config_token = var_child_runnable_config.set(config) + current_context = None if ( (callbacks := config.get("callbacks")) and ( @@ -141,9 +145,32 @@ def _set_config_context(config: RunnableConfig) -> None: ) and (run := tracer.run_map.get(str(parent_run_id))) ): - from langsmith.run_helpers import _set_tracing_context + from langsmith.run_helpers import _set_tracing_context, get_tracing_context + current_context = get_tracing_context() _set_tracing_context({"parent": run}) + return config_token, current_context + + +@contextmanager +def set_config_context( + config: RunnableConfig, ctx: Optional[Context] = None +) -> Generator[Context, None, None]: + """Set the child Runnable config + tracing context. + + Args: + config (RunnableConfig): The config to set. + """ + from langsmith.run_helpers import _set_tracing_context + + ctx = ctx if ctx is not None else copy_context() + config_token, current_context = ctx.run(_set_config_context, config) + try: + yield ctx + finally: + ctx.run(var_child_runnable_config.reset, config_token) + if current_context is not None: + ctx.run(_set_tracing_context, current_context) def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index f932ce3589e..20bfa7e7528 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -2,7 +2,6 @@ import asyncio import inspect import typing from collections.abc import AsyncIterator, Iterator, Sequence -from contextvars import copy_context from functools import wraps from typing import ( TYPE_CHECKING, @@ -18,12 +17,12 @@ from typing_extensions import override from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( RunnableConfig, - _set_config_context, ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, get_config_list, patch_config, + set_config_context, ) from langchain_core.runnables.utils import ( ConfigurableFieldSpec, @@ -174,14 +173,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): if self.exception_key and last_error is not None: input[self.exception_key] = last_error child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - output = context.run( - runnable.invoke, - input, - config, - **kwargs, - ) + with set_config_context(child_config) as context: + output = context.run( + runnable.invoke, + input, + config, + **kwargs, + ) except self.exceptions_to_handle as e: if first_error is None: first_error = e @@ -228,13 +226,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): if self.exception_key and last_error is not None: input[self.exception_key] = last_error child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - coro = runnable.ainvoke(input, child_config, **kwargs) - if asyncio_accepts_context(): - output = await asyncio.create_task(coro, context=context) # type: ignore - else: - output = await coro + with set_config_context(child_config) as context: + coro = context.run(runnable.ainvoke, input, config, **kwargs) + if asyncio_accepts_context(): + output = await asyncio.create_task(coro, context=context) # type: ignore + else: + output = await coro except self.exceptions_to_handle as e: if first_error is None: first_error = e @@ -475,14 +472,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): if self.exception_key and last_error is not None: input[self.exception_key] = last_error child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - stream = context.run( - runnable.stream, - input, - **kwargs, - ) - chunk: Output = context.run(next, stream) # type: ignore + with set_config_context(child_config) as context: + stream = context.run( + runnable.stream, + input, + **kwargs, + ) + chunk: Output = context.run(next, stream) # type: ignore except self.exceptions_to_handle as e: first_error = e if first_error is None else first_error last_error = e @@ -539,20 +535,19 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): if self.exception_key and last_error is not None: input[self.exception_key] = last_error child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - stream = runnable.astream( - input, - child_config, - **kwargs, - ) - if asyncio_accepts_context(): - chunk: Output = await asyncio.create_task( # type: ignore[call-arg] - py_anext(stream), # type: ignore[arg-type] - context=context, + with set_config_context(child_config) as context: + stream = runnable.astream( + input, + child_config, + **kwargs, ) - else: - chunk = cast(Output, await py_anext(stream)) + if asyncio_accepts_context(): + chunk: Output = await asyncio.create_task( # type: ignore[call-arg] + py_anext(stream), # type: ignore[arg-type] + context=context, + ) + else: + chunk = cast(Output, await py_anext(stream)) except self.exceptions_to_handle as e: first_error = e if first_error is None else first_error last_error = e diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 5e1a8a47523..87a8c80b865 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -6,7 +6,6 @@ import inspect import json import warnings from abc import ABC, abstractmethod -from contextvars import copy_context from inspect import signature from typing import ( TYPE_CHECKING, @@ -52,7 +51,7 @@ from langchain_core.runnables import ( patch_config, run_in_executor, ) -from langchain_core.runnables.config import _set_config_context +from langchain_core.runnables.config import set_config_context from langchain_core.runnables.utils import asyncio_accepts_context from langchain_core.utils.function_calling import ( _parse_google_docstring, @@ -722,14 +721,15 @@ class ChildTool(BaseTool): error_to_raise: Union[Exception, KeyboardInterrupt, None] = None try: child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) - if signature(self._run).parameters.get("run_manager"): - tool_kwargs = tool_kwargs | {"run_manager": run_manager} - if config_param := _get_runnable_config_param(self._run): - tool_kwargs = tool_kwargs | {config_param: config} - response = context.run(self._run, *tool_args, **tool_kwargs) + with set_config_context(child_config) as context: + tool_args, tool_kwargs = self._to_args_and_kwargs( + tool_input, tool_call_id + ) + if signature(self._run).parameters.get("run_manager"): + tool_kwargs = tool_kwargs | {"run_manager": run_manager} + if config_param := _get_runnable_config_param(self._run): + tool_kwargs = tool_kwargs | {config_param: config} + response = context.run(self._run, *tool_args, **tool_kwargs) if self.response_format == "content_and_artifact": if not isinstance(response, tuple) or len(response) != 2: msg = ( @@ -832,21 +832,20 @@ class ChildTool(BaseTool): try: tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input, tool_call_id) child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - func_to_check = ( - self._run if self.__class__._arun is BaseTool._arun else self._arun - ) - if signature(func_to_check).parameters.get("run_manager"): - tool_kwargs["run_manager"] = run_manager - if config_param := _get_runnable_config_param(func_to_check): - tool_kwargs[config_param] = config + with set_config_context(child_config) as context: + func_to_check = ( + self._run if self.__class__._arun is BaseTool._arun else self._arun + ) + if signature(func_to_check).parameters.get("run_manager"): + tool_kwargs["run_manager"] = run_manager + if config_param := _get_runnable_config_param(func_to_check): + tool_kwargs[config_param] = config - coro = context.run(self._arun, *tool_args, **tool_kwargs) - if asyncio_accepts_context(): - response = await asyncio.create_task(coro, context=context) # type: ignore - else: - response = await coro + coro = self._arun(*tool_args, **tool_kwargs) + if asyncio_accepts_context(): + response = await asyncio.create_task(coro, context=context) # type: ignore + else: + response = await coro if self.response_format == "content_and_artifact": if not isinstance(response, tuple) or len(response) != 2: msg = ( diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index a947f162fcd..5e516a7817d 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -89,11 +89,11 @@ def tracing_v2_enabled( tags=tags, client=client, ) + token = tracing_v2_callback_var.set(cb) try: - tracing_v2_callback_var.set(cb) yield cb finally: - tracing_v2_callback_var.set(None) + tracing_v2_callback_var.reset(token) @contextmanager @@ -109,9 +109,11 @@ def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]: run_id = runs_cb.traced_runs[0].id """ cb = RunCollectorCallbackHandler() - run_collector_var.set(cb) - yield cb - run_collector_var.set(None) + token = run_collector_var.set(cb) + try: + yield cb + finally: + run_collector_var.reset(token) def _get_trace_callbacks(