diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 357a135adf8..619d86368d7 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import collections +import functools import inspect import threading from abc import ABC, abstractmethod @@ -68,8 +69,8 @@ from langchain_core.runnables.utils import ( Input, Output, accepts_config, - accepts_context, accepts_run_manager, + asyncio_accepts_context, create_model, gather_with_concurrency, get_function_first_arg_dict_keys, @@ -1830,7 +1831,7 @@ class Runnable(Generic[Input, Output], ABC): coro = acall_func_with_variable_args( func, input, config, run_manager, **kwargs ) - if accepts_context(asyncio.create_task): + if asyncio_accepts_context(): output: Output = await asyncio.create_task(coro, context=context) # type: ignore else: output = await coro @@ -2156,7 +2157,7 @@ class Runnable(Generic[Input, Output], ABC): iterator = iterator_ try: while True: - if accepts_context(asyncio.create_task): + if asyncio_accepts_context(): chunk: Output = await asyncio.create_task( # type: ignore[call-arg] py_anext(iterator), # type: ignore[arg-type] context=context, @@ -2869,10 +2870,12 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) + context = copy_context() + context.run(_set_config_context, config) if i == 0: - input = step.invoke(input, config, **kwargs) + input = context.run(step.invoke, input, config, **kwargs) else: - input = step.invoke(input, config) + input = context.run(step.invoke, input, config) # finish the root run except BaseException as e: run_manager.on_chain_error(e) @@ -2907,10 +2910,16 @@ class RunnableSequence(RunnableSerializable[Input, Output]): config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) + context = copy_context() + context.run(_set_config_context, config) if i == 0: - input = await step.ainvoke(input, config, **kwargs) + part = functools.partial(step.ainvoke, input, config, **kwargs) else: - input = await step.ainvoke(input, config) + part = functools.partial(step.ainvoke, input, config) + if asyncio_accepts_context(): + input = await asyncio.create_task(part(), context=context) # type: ignore + else: + input = await asyncio.create_task(part()) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) @@ -3542,21 +3551,30 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): run_id=config.pop("run_id", None), ) + def _invoke_step( + step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str + ) -> Any: + child_config = patch_config( + config, + # mark each step as a child run + callbacks=run_manager.get_child(f"map:key:{key}"), + ) + context = copy_context() + context.run(_set_config_context, child_config) + return context.run( + step.invoke, + input, + child_config, + ) + # gather results from all steps try: # copy to avoid issues from the caller mutating the steps during invoke() steps = dict(self.steps__) + with get_executor_for_config(config) as executor: futures = [ - executor.submit( - step.invoke, - input, - # mark each step as a child run - patch_config( - config, - callbacks=run_manager.get_child(f"map:key:{key}"), - ), - ) + executor.submit(_invoke_step, step, input, config, key) for key, step in steps.items() ] output = {key: future.result() for key, future in zip(steps, futures)} @@ -3585,18 +3603,34 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): run_id=config.pop("run_id", None), ) + async def _ainvoke_step( + step: Runnable[Input, Any], input: Input, config: RunnableConfig, key: str + ) -> Any: + child_config = patch_config( + config, + callbacks=run_manager.get_child(f"map:key:{key}"), + ) + context = copy_context() + context.run(_set_config_context, child_config) + if asyncio_accepts_context(): + return await asyncio.create_task( # type: ignore + step.ainvoke(input, child_config), context=context + ) + else: + return await asyncio.create_task(step.ainvoke(input, child_config)) + # gather results from all steps try: # copy to avoid issues from the caller mutating the steps during invoke() steps = dict(self.steps__) results = await asyncio.gather( *( - step.ainvoke( + _ainvoke_step( + step, input, # mark each step as a child run - patch_config( - config, callbacks=run_manager.get_child(f"map:key:{key}") - ), + config, + key, ) for key, step in steps.items() ) diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 0d71b7fa9f7..da6aa98a4b6 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -348,7 +348,6 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: base["callbacks"] = mngr else: # base_callbacks is also a manager - manager = base_callbacks.__class__( parent_run_id=base_callbacks.parent_run_id or these_callbacks.parent_run_id, diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 5187e299c1f..b3249b47cc4 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -1,12 +1,12 @@ import asyncio import inspect import typing +from contextvars import copy_context from functools import wraps from typing import ( TYPE_CHECKING, Any, AsyncIterator, - Awaitable, Dict, Iterator, List, @@ -23,6 +23,7 @@ from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( RunnableConfig, + _set_config_context, ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, @@ -33,6 +34,7 @@ from langchain_core.runnables.utils import ( ConfigurableFieldSpec, Input, Output, + asyncio_accepts_context, get_unique_config_specs, ) from langchain_core.utils.aiter import py_anext @@ -172,9 +174,12 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): try: if self.exception_key and last_error is not None: input[self.exception_key] = last_error - output = runnable.invoke( + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(_set_config_context, child_config) + output = context.run( + runnable.invoke, input, - patch_config(config, callbacks=run_manager.get_child()), **kwargs, ) except self.exceptions_to_handle as e: @@ -220,11 +225,14 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): try: if self.exception_key and last_error is not None: input[self.exception_key] = last_error - output = await runnable.ainvoke( - input, - patch_config(config, callbacks=run_manager.get_child()), - **kwargs, - ) + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(_set_config_context, child_config) + coro = runnable.ainvoke(input, child_config, **kwargs) + if asyncio_accepts_context(): + output = await asyncio.create_task(coro, context=context) # type: ignore + else: + output = await coro except self.exceptions_to_handle as e: if first_error is None: first_error = e @@ -460,12 +468,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): try: if self.exception_key and last_error is not None: input[self.exception_key] = last_error - stream = runnable.stream( + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(_set_config_context, child_config) + stream = context.run( + runnable.stream, input, - patch_config(config, callbacks=run_manager.get_child()), **kwargs, ) - chunk = next(stream) + chunk: Output = context.run(next, stream) # type: ignore except self.exceptions_to_handle as e: first_error = e if first_error is None else first_error last_error = e @@ -520,12 +531,21 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): try: if self.exception_key and last_error is not None: input[self.exception_key] = last_error + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(_set_config_context, child_config) stream = runnable.astream( input, - patch_config(config, callbacks=run_manager.get_child()), + child_config, **kwargs, ) - chunk = await cast(Awaitable[Output], py_anext(stream)) + if asyncio_accepts_context(): + chunk: Output = await asyncio.create_task( # type: ignore[call-arg] + py_anext(stream), # type: ignore[arg-type] + context=context, + ) + else: + chunk = cast(Output, await py_anext(stream)) except self.exceptions_to_handle as e: first_error = e if first_error is None else first_error last_error = e diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index e40c5da89e6..1d71f47d1bd 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -118,6 +118,11 @@ def accepts_context(callable: Callable[..., Any]) -> bool: return False +@lru_cache(maxsize=1) +def asyncio_accepts_context() -> bool: + return accepts_context(asyncio.create_task) + + class IsLocalDict(ast.NodeVisitor): """Check if a name is a local dict.""" diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 94ceb0d7090..729d68ffc3b 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -87,7 +87,7 @@ from langchain_core.runnables.config import ( patch_config, run_in_executor, ) -from langchain_core.runnables.utils import accepts_context +from langchain_core.runnables.utils import asyncio_accepts_context from langchain_core.utils.function_calling import ( _parse_google_docstring, _py_38_safe_origin, @@ -694,7 +694,7 @@ class ChildTool(BaseTool): tool_kwargs[config_param] = config coro = context.run(self._arun, *tool_args, **tool_kwargs) - if accepts_context(asyncio.create_task): + if asyncio_accepts_context(): response = await asyncio.create_task(coro, context=context) # type: ignore else: response = await coro diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 2f7f7ea2527..71744bf795e 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -1,12 +1,13 @@ import json import sys +from typing import Any, AsyncGenerator, Generator from unittest.mock import MagicMock, patch import pytest from langsmith import Client, traceable from langsmith.run_helpers import tracing_context -from langchain_core.runnables.base import RunnableLambda +from langchain_core.runnables.base import RunnableLambda, RunnableParallel from langchain_core.tracers.langchain import LangChainTracer @@ -199,3 +200,141 @@ def test_tracing_enable_disable( assert len(mock_posts) == 1 else: assert not mock_posts + + +@pytest.mark.parametrize( + "method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"] +) +async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None: + if method.startswith("a") and sys.version_info < (3, 11): + pytest.skip("Asyncio context vars require Python 3.11+") + mock_session = MagicMock() + mock_client_ = Client( + session=mock_session, api_key="test", auto_batch_tracing=False + ) + tracer = LangChainTracer(client=mock_client_) + + @RunnableLambda + def my_child_function(a: int) -> int: + return a + 2 + + if method.startswith("a"): + + async def other_thing(a: int) -> AsyncGenerator[int, None]: + yield 1 + + else: + + def other_thing(a: int) -> Generator[int, None, None]: # type: ignore + yield 1 + + parallel = RunnableParallel( + chain_result=my_child_function.with_config(tags=["atag"]), + other_thing=other_thing, + ) + + def before(x: int) -> int: + return x + + def after(x: dict) -> int: + return x["chain_result"] + + sequence = before | parallel | after + if method.startswith("a"): + + @RunnableLambda # type: ignore + async def parent(a: int) -> int: + return await sequence.ainvoke(a) + + else: + + @RunnableLambda + def parent(a: int) -> int: + return sequence.invoke(a) + + # Now run the chain and check the resulting posts + cb = [tracer] + if method == "invoke": + res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore + elif method == "ainvoke": + res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore + elif method == "stream": + results = list(parent.stream(1, {"callbacks": cb})) # type: ignore + res = results[-1] + elif method == "astream": + results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore + res = results[-1] + elif method == "batch": + res = parent.batch([1], {"callbacks": cb})[0] # type: ignore + elif method == "abatch": + res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore + else: + raise ValueError(f"Unknown method {method}") + assert res == 3 + posts = _get_posts(mock_client_) + name_order = [ + "parent", + "RunnableSequence", + "before", + "RunnableParallel", + ["my_child_function", "other_thing"], + "after", + ] + expected_parents = { + "parent": None, + "RunnableSequence": "parent", + "before": "RunnableSequence", + "RunnableParallel": "RunnableSequence", + "my_child_function": "RunnableParallel", + "other_thing": "RunnableParallel", + "after": "RunnableSequence", + } + assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order]) + prev_dotted_order = None + dotted_order_map = {} + id_map = {} + parent_id_map = {} + i = 0 + for name in name_order: + if isinstance(name, list): + for n in name: + matching_post = next( + p for p in posts[i : i + len(name)] if p["name"] == n + ) + assert matching_post + dotted_order = matching_post["dotted_order"] + if prev_dotted_order is not None: + assert dotted_order > prev_dotted_order + dotted_order_map[n] = dotted_order + id_map[n] = matching_post["id"] + parent_id_map[n] = matching_post.get("parent_run_id") + i += len(name) + continue + else: + assert posts[i]["name"] == name + dotted_order = posts[i]["dotted_order"] + if prev_dotted_order is not None and not str( + expected_parents[name] + ).startswith("RunnableParallel"): + assert ( + dotted_order > prev_dotted_order + ), f"{name} not after {name_order[i-1]}" + prev_dotted_order = dotted_order + if name in dotted_order_map: + raise ValueError(f"Duplicate name {name}") + dotted_order_map[name] = dotted_order + id_map[name] = posts[i]["id"] + parent_id_map[name] = posts[i].get("parent_run_id") + i += 1 + + # Now check the dotted orders + for name, parent_ in expected_parents.items(): + dotted_order = dotted_order_map[name] + if parent_ is not None: + parent_dotted_order = dotted_order_map[parent_] + assert dotted_order.startswith( + parent_dotted_order + ), f"{name}, {parent_dotted_order} not in {dotted_order}" + assert str(parent_id_map[name]) == str(id_map[parent_]) + else: + assert dotted_order.split(".")[0] == dotted_order