diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 3fceb33f009..6908631869e 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -1183,6 +1183,7 @@ class CallbackManager(BaseCallbackManager): self, serialized: Dict[str, Any], prompts: List[str], + run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[CallbackManagerForLLMRun]: """Run when LLM starts running. @@ -1197,8 +1198,9 @@ class CallbackManager(BaseCallbackManager): prompt as an LLM run. """ managers = [] - for prompt in prompts: - run_id_ = uuid.uuid4() + for i, prompt in enumerate(prompts): + # Can't have duplicate runs with the same run ID (if provided) + run_id_ = run_id if i == 0 and run_id is not None else uuid.uuid4() handle_event( self.handlers, "on_llm_start", @@ -1231,6 +1233,7 @@ class CallbackManager(BaseCallbackManager): self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], + run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[CallbackManagerForLLMRun]: """Run when LLM starts running. @@ -1247,7 +1250,11 @@ class CallbackManager(BaseCallbackManager): managers = [] for message_list in messages: - run_id_ = uuid.uuid4() + if run_id is not None: + run_id_ = run_id + run_id = None + else: + run_id_ = uuid.uuid4() handle_event( self.handlers, "on_chat_model_start", @@ -1520,6 +1527,7 @@ class AsyncCallbackManager(BaseCallbackManager): self, serialized: Dict[str, Any], prompts: List[str], + run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running. 
@@ -1539,7 +1547,11 @@ class AsyncCallbackManager(BaseCallbackManager): managers = [] for prompt in prompts: - run_id_ = uuid.uuid4() + if run_id is not None: + run_id_ = run_id + run_id = None + else: + run_id_ = uuid.uuid4() tasks.append( ahandle_event( @@ -1577,6 +1589,7 @@ class AsyncCallbackManager(BaseCallbackManager): self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], + run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running. @@ -1595,7 +1608,11 @@ class AsyncCallbackManager(BaseCallbackManager): managers = [] for message_list in messages: - run_id_ = uuid.uuid4() + if run_id is not None: + run_id_ = run_id + run_id = None + else: + run_id_ = uuid.uuid4() tasks.append( ahandle_event( diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 7c346c24eab..8472dbcc033 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import inspect +import uuid import warnings from abc import ABC, abstractmethod from typing import ( @@ -234,6 +235,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): invocation_params=params, options=options, name=config.get("run_name"), + run_id=config.pop("run_id", None), batch_size=1, ) generation: Optional[ChatGenerationChunk] = None @@ -312,6 +314,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): invocation_params=params, options=options, name=config.get("run_name"), + run_id=config.pop("run_id", None), batch_size=1, ) generation: Optional[ChatGenerationChunk] = None @@ -371,6 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, + run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> LLMResult: 
"""Pass a sequence of prompts to the model and return model generations. @@ -415,6 +419,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): invocation_params=params, options=options, name=run_name, + run_id=run_id, batch_size=len(messages), ) results = [] @@ -456,6 +461,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, + run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> LLMResult: """Asynchronously pass a sequence of prompts to a model and return generations. @@ -502,6 +508,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): options=options, name=run_name, batch_size=len(messages), + run_id=run_id, ) results = await asyncio.gather( diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index f2fd7120286..789d9baa2ab 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -7,6 +7,7 @@ import functools import inspect import json import logging +import uuid import warnings from abc import ABC, abstractmethod from pathlib import Path @@ -271,6 +272,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + run_id=config.pop("run_id", None), **kwargs, ) .generations[0][0] @@ -293,6 +295,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + run_id=config.pop("run_id", None), **kwargs, ) return llm_result.generations[0][0].text @@ -423,6 +426,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): invocation_params=params, options=options, name=config.get("run_name"), + run_id=config.pop("run_id", None), batch_size=1, ) generation: Optional[GenerationChunk] = None @@ -499,6 +503,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): 
invocation_params=params, options=options, name=config.get("run_name"), + run_id=config.pop("run_id", None), batch_size=1, ) generation: Optional[GenerationChunk] = None @@ -632,6 +637,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): tags: Optional[Union[List[str], List[List[str]]]] = None, metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, run_name: Optional[Union[str, List[str]]] = None, + run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None, **kwargs: Any, ) -> LLMResult: """Pass a sequence of prompts to a model and return generations. @@ -717,7 +723,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) ] * len(prompts) run_name_list = [cast(Optional[str], run_name)] * len(prompts) - + run_ids_list = self._get_run_ids_list(run_id, prompts) params = self.dict() params["stop"] = stop options = {"stop": stop} @@ -744,9 +750,10 @@ class BaseLLM(BaseLanguageModel[str], ABC): options=options, name=run_name, batch_size=len(prompts), + run_id=run_id_, )[0] - for callback_manager, prompt, run_name in zip( - callback_managers, prompts, run_name_list + for callback_manager, prompt, run_name, run_id_ in zip( + callback_managers, prompts, run_name_list, run_ids_list ) ] output = self._generate_helper( @@ -782,6 +789,21 @@ class BaseLLM(BaseLanguageModel[str], ABC): generations = [existing_prompts[i] for i in range(len(prompts))] return LLMResult(generations=generations, llm_output=llm_output, run=run_info) + @staticmethod + def _get_run_ids_list( + run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list + ) -> list: + if run_id is None: + return [None] * len(prompts) + if isinstance(run_id, list): + if len(run_id) != len(prompts): + raise ValueError( + "Number of manually provided run_id's does not match batch length." 
+ f" {len(run_id)} != {len(prompts)}" + ) + return run_id + return [run_id] + [None] * (len(prompts) - 1) + async def _agenerate_helper( self, prompts: List[str], @@ -833,6 +855,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): tags: Optional[Union[List[str], List[List[str]]]] = None, metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, run_name: Optional[Union[str, List[str]]] = None, + run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None, **kwargs: Any, ) -> LLMResult: """Asynchronously pass a sequence of prompts to a model and return generations. @@ -909,7 +932,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) ] * len(prompts) run_name_list = [cast(Optional[str], run_name)] * len(prompts) - + run_ids_list = self._get_run_ids_list(run_id, prompts) params = self.dict() params["stop"] = stop options = {"stop": stop} @@ -937,9 +960,10 @@ class BaseLLM(BaseLanguageModel[str], ABC): options=options, name=run_name, batch_size=len(prompts), + run_id=run_id_, ) - for callback_manager, prompt, run_name in zip( - callback_managers, prompts, run_name_list + for callback_manager, prompt, run_name, run_id_ in zip( + callback_managers, prompts, run_name_list, run_ids_list ) ] ) diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 3854184d377..b7c847150b7 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -230,6 +230,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): dumpd(self), query, name=run_name, + run_id=kwargs.pop("run_id", None), ) try: _kwargs = kwargs if self._expects_other_args else {} @@ -286,6 +287,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): dumpd(self), query, name=run_name, + run_id=kwargs.pop("run_id", None), ) try: _kwargs = kwargs if self._expects_other_args else {} diff --git a/libs/core/langchain_core/runnables/base.py 
b/libs/core/langchain_core/runnables/base.py index 82ea735e11e..88d71407ca4 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1448,6 +1448,7 @@ class Runnable(Generic[Input, Output], ABC): input, run_type=run_type, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) @@ -1495,6 +1496,7 @@ class Runnable(Generic[Input, Output], ABC): input, run_type=run_type, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) @@ -1547,6 +1549,7 @@ class Runnable(Generic[Input, Output], ABC): input, run_type=run_type, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) for callback_manager, input, config in zip( callback_managers, input, configs @@ -1619,6 +1622,7 @@ class Runnable(Generic[Input, Output], ABC): input, run_type=run_type, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) for callback_manager, input, config in zip( callback_managers, input, configs @@ -1694,6 +1698,7 @@ class Runnable(Generic[Input, Output], ABC): {"input": ""}, run_type=run_type, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) @@ -1781,6 +1786,7 @@ class Runnable(Generic[Input, Output], ABC): {"input": ""}, run_type=run_type, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) try: child_config = patch_config(config, callbacks=run_manager.get_child()) @@ -2262,7 +2268,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") 
or self.get_name() + dumpd(self), + input, + name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) # invoke all steps in sequence @@ -2296,7 +2305,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.get_name() + dumpd(self), + input, + name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) # invoke all steps in sequence @@ -2354,6 +2366,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) for cm, input, config in zip(callback_managers, inputs, configs) ] @@ -2478,6 +2491,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) for cm, input, config in zip(callback_managers, inputs, configs) ) @@ -2885,7 +2899,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): ) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.get_name() + dumpd(self), + input, + name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) # gather results from all steps @@ -2925,7 +2942,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") or self.get_name() + dumpd(self), + input, + name=config.get("run_name") or self.get_name(), + run_id=config.pop("run_id", None), ) # gather results from all steps diff --git a/libs/core/langchain_core/runnables/branch.py 
b/libs/core/langchain_core/runnables/branch.py index 38cec8b32cd..ddafdd2a6e7 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -183,6 +183,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name"), + run_id=config.pop("run_id", None), ) try: @@ -231,6 +232,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name"), + run_id=config.pop("run_id", None), ) try: for idx, branch in enumerate(self.branches): @@ -282,6 +284,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name"), + run_id=config.pop("run_id", None), ) final_output: Optional[Output] = None final_output_supported = True @@ -356,6 +359,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name"), + run_id=config.pop("run_id", None), ) final_output: Optional[Output] = None final_output_supported = True diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 015cb8f4664..b89a3da38fc 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +import uuid +import warnings from concurrent.futures import Executor, Future, ThreadPoolExecutor from contextlib import contextmanager from contextvars import ContextVar, copy_context @@ -95,6 +97,12 @@ class RunnableConfig(TypedDict, total=False): configurable. """ + run_id: Optional[uuid.UUID] + """ + Unique identifier for the tracer run for this call. If not provided, a new UUID + will be generated. 
+ """ + var_child_runnable_config = ContextVar( "child_runnable_config", default=RunnableConfig() @@ -116,6 +124,7 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: metadata={}, callbacks=None, recursion_limit=25, + run_id=None, ) if var_config := var_child_runnable_config.get(): empty.update( @@ -158,11 +167,21 @@ def get_config_list( f"but got {len(config)} configs for {length} inputs" ) - return ( - list(map(ensure_config, config)) - if isinstance(config, list) - else [ensure_config(config) for _ in range(length)] - ) + if isinstance(config, list): + return list(map(ensure_config, config)) + if length > 1 and isinstance(config, dict) and config.get("run_id") is not None: + warnings.warn( + "Provided run_id will be used only for the first element of the batch.", + category=RuntimeWarning, + ) + subsequent = cast( + RunnableConfig, {k: v for k, v in config.items() if k != "run_id"} + ) + return [ + ensure_config(subsequent) if i else ensure_config(config) + for i in range(length) + ] + return [ensure_config(config) for i in range(length)] def patch_config( @@ -199,6 +218,8 @@ def patch_config( config["callbacks"] = callbacks if "run_name" in config: del config["run_name"] + if "run_id" in config: + del config["run_id"] if recursion_limit is not None: config["recursion_limit"] = recursion_limit if max_concurrency is not None: diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 9ecbe6cdfc6..131061e9e9f 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -156,7 +156,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") + dumpd(self), + input, + name=config.get("run_name"), + run_id=config.pop("run_id", None), )
first_error = None last_error = None @@ -200,7 +203,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") + dumpd(self), + input, + name=config.get("run_name"), + run_id=config.pop("run_id", None), ) first_error = None @@ -270,6 +276,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): dumpd(self), input if isinstance(input, dict) else {"input": input}, name=config.get("run_name"), + run_id=config.pop("run_id", None), ) for cm, input, config in zip(callback_managers, inputs, configs) ] @@ -362,6 +369,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): dumpd(self), input, name=config.get("run_name"), + run_id=config.pop("run_id", None), ) for cm, input, config in zip(callback_managers, inputs, configs) ) @@ -436,7 +444,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") + dumpd(self), + input, + name=config.get("run_name"), + run_id=config.pop("run_id", None), ) first_error = None last_error = None @@ -493,7 +504,10 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, name=config.get("run_name") + dumpd(self), + input, + name=config.get("run_name"), + run_id=config.pop("run_id", None), ) first_error = None last_error = None diff --git a/libs/core/langchain_core/runnables/learnable.py b/libs/core/langchain_core/runnables/learnable.py new file mode 100644 index 00000000000..f56815858bc --- /dev/null +++ b/libs/core/langchain_core/runnables/learnable.py @@ -0,0 
+1,15 @@ +# from langchain_core.runnables.base import RunnableBinding + + +# class RunnableLearnable(RunnableBinding): +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self.parameters = [] + +# def backward(self): +# for param in self.parameters: +# param.backward() + +# def update(self, optimizer): +# for param in self.parameters: +# optimizer.update(param) diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index a67a642329c..585bd1d53e6 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -20,6 +20,7 @@ tool for the job. from __future__ import annotations import inspect +import uuid import warnings from abc import abstractmethod from inspect import signature @@ -243,6 +244,7 @@ class ChildTool(BaseTool): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + run_id=config.pop("run_id", None), **kwargs, ) @@ -259,6 +261,7 @@ class ChildTool(BaseTool): tags=config.get("tags"), metadata=config.get("metadata"), run_name=config.get("run_name"), + run_id=config.pop("run_id", None), **kwargs, ) @@ -339,6 +342,7 @@ class ChildTool(BaseTool): tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, + run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> Any: """Run the tool.""" @@ -362,6 +366,7 @@ class ChildTool(BaseTool): tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, + run_id=run_id, # Inputs by definition should always be dicts. 
# For now, it's unclear whether this assumption is ever violated, # but if it is we will send a `None` value to the callback instead @@ -430,6 +435,7 @@ class ChildTool(BaseTool): tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, + run_id: Optional[uuid.UUID] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -453,6 +459,7 @@ class ChildTool(BaseTool): color=start_color, name=run_name, inputs=tool_input, + run_id=run_id, **kwargs, ) try: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index c480d5bb22a..80b5dfa054b 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1,4 +1,5 @@ import sys +import uuid from functools import partial from operator import itemgetter from typing import ( @@ -136,6 +137,22 @@ class FakeTracer(BaseTracer): self.runs.append(self._copy_run(run)) + def flattened_runs(self) -> List[Run]: + q = [] + self.runs + result = [] + while q: + parent = q.pop() + result.append(parent) + if parent.child_runs: + q.extend(parent.child_runs) + return result + + @property + def run_ids(self) -> List[Optional[uuid.UUID]]: + runs = self.flattened_runs() + uuids_map = {v: k for k, v in self.uuids_map.items()} + return [uuids_map.get(r.id) for r in runs] + class FakeRunnable(Runnable[str, int]): def invoke( @@ -1367,6 +1384,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: recursion_limit=25, configurable={"hello": "there"}, metadata={"hello": "there", "bye": "now"}, + run_id=None, ), ) spy.reset_mock() @@ -1508,6 +1526,7 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, + run_id=None, ), ), mocker.call( @@ -1517,6 +1536,7 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, + 
run_id=None, ), ), ] @@ -1542,6 +1562,7 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, + run_id=None, ), ) second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld") @@ -1552,6 +1573,7 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, + run_id=None, ), ) @@ -1620,6 +1642,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, recursion_limit=25, + run_id=None, ), ), mocker.call( @@ -1629,6 +1652,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, recursion_limit=25, + run_id=None, ), ), ] @@ -4822,27 +4846,45 @@ async def test_runnable_gen_context_config() -> None: } tracer = FakeTracer() - assert runnable.invoke(None, {"callbacks": [tracer]}) == 6 + run_id = uuid.uuid4() + assert runnable.invoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6 assert len(tracer.runs) == 1 assert tracer.runs[0].outputs == {"output": 6} assert len(tracer.runs[0].child_runs) == 3 assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + run_ids = tracer.run_ids + assert run_id in run_ids + assert len(run_ids) == len(set(run_ids)) tracer.runs.clear() assert list(runnable.stream(None)) == [1, 2, 3] assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" tracer = FakeTracer() - assert list(runnable.stream(None, {"callbacks": [tracer]})) == [1, 2, 3] + run_id = uuid.uuid4() + assert list(runnable.stream(None, {"callbacks": [tracer], "run_id": run_id})) == [ + 1, + 2, + 3, + ] assert len(tracer.runs) == 1 assert tracer.runs[0].outputs == {"output": 6} assert len(tracer.runs[0].child_runs) == 3 assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or 
{})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + run_ids = tracer.run_ids + assert run_id in run_ids + assert len(run_ids) == len(set(run_ids)) + tracer.runs.clear() tracer = FakeTracer() - assert runnable.batch([None, None], {"callbacks": [tracer]}) == [6, 6] + run_id = uuid.uuid4() + + with pytest.warns(RuntimeWarning): + assert runnable.batch( + [None, None], {"callbacks": [tracer], "run_id": run_id} + ) == [6, 6] assert len(tracer.runs) == 2 assert tracer.runs[0].outputs == {"output": 6} assert tracer.runs[1].outputs == {"output": 6} @@ -4865,19 +4907,30 @@ async def test_runnable_gen_context_config() -> None: arunnable = RunnableGenerator(agen) tracer = FakeTracer() - assert await arunnable.ainvoke(None, {"callbacks": [tracer]}) == 6 + + run_id = uuid.uuid4() + assert await arunnable.ainvoke(None, {"callbacks": [tracer], "run_id": run_id}) == 6 assert len(tracer.runs) == 1 assert tracer.runs[0].outputs == {"output": 6} assert len(tracer.runs[0].child_runs) == 3 assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + run_ids = tracer.run_ids + assert run_id in run_ids + assert len(run_ids) == len(set(run_ids)) tracer.runs.clear() assert [p async for p in arunnable.astream(None)] == [1, 2, 3] assert len(tracer.runs) == 0, "callbacks doesn't persist from previous call" tracer = FakeTracer() - assert [p async for p in arunnable.astream(None, {"callbacks": [tracer]})] == [ + run_id = uuid.uuid4() + assert [ + p + async for p in arunnable.astream( + None, {"callbacks": [tracer], "run_id": run_id} + ) + ] == [ 1, 2, 3, @@ -4887,9 +4940,16 @@ async def test_runnable_gen_context_config() -> None: assert len(tracer.runs[0].child_runs) == 3 assert [r.inputs["input"] for r in tracer.runs[0].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[0].child_runs] == [1, 2, 3] + run_ids = tracer.run_ids + assert 
run_id in run_ids + assert len(run_ids) == len(set(run_ids)) tracer = FakeTracer() - assert await arunnable.abatch([None, None], {"callbacks": [tracer]}) == [6, 6] + run_id = uuid.uuid4() + with pytest.warns(RuntimeWarning): + assert await arunnable.abatch( + [None, None], {"callbacks": [tracer], "run_id": run_id} + ) == [6, 6] assert len(tracer.runs) == 2 assert tracer.runs[0].outputs == {"output": 6} assert tracer.runs[1].outputs == {"output": 6}