From a3c69cf41d6b8584a9cbae98813e3c028b75624e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 11:53:29 +0200 Subject: [PATCH 01/12] Add .with_config() method to Runnables which allows binding any config values to a Runnable --- .../langchain/schema/runnable/base.py | 80 +++++++++++-- .../langchain/schema/runnable/config.py | 5 +- .../runnable/__snapshots__/test_runnable.ambr | 3 +- .../schema/runnable/test_runnable.py | 106 ++++++++++++++++++ 4 files changed, 181 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index bdbd7fc6999..98f7d758635 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -210,7 +210,20 @@ class Runnable(Generic[Input, Output], ABC): """ Bind arguments to a Runnable, returning a new Runnable. """ - return RunnableBinding(bound=self, kwargs=kwargs) + return RunnableBinding(bound=self, kwargs=kwargs, config={}) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + """ + Bind config to a Runnable, returning a new Runnable. + """ + return RunnableBinding( + bound=self, config={**(config or {}), **kwargs}, kwargs={} + ) def map(self) -> Runnable[List[Input], List[Output]]: """ @@ -1479,6 +1492,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): kwargs: Mapping[str, Any] + config: Mapping[str, Any] = Field(default_factory=dict) + class Config: arbitrary_types_allowed = True @@ -1491,7 +1506,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): return self.__class__.__module__.split(".")[:-1] def bind(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) + return self.__class__( + bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} + ) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config={**self.config, **(config or {}), **kwargs}, + ) def invoke( self, @@ -1499,7 +1528,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return self.bound.invoke(input, config, **{**self.kwargs, **kwargs}) + return self.bound.invoke( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) async def ainvoke( self, @@ -1507,7 +1538,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs}) + return await self.bound.ainvoke( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) def batch( self, @@ -1515,7 +1548,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs}) + configs = ( + [{**self.config, **(conf or {})} for conf in config] + if isinstance(config, list) + else [ + patch_config({**self.config, **(config or {})}, deep_copy_locals=True) + for _ in range(len(inputs)) + ] + ) + return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) async def abatch( self, @@ -1523,7 +1564,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs}) + configs = ( + [{**self.config, **(conf or {})} for conf in config] + if isinstance(config, list) + else [ + patch_config({**self.config, **(config or {})}, deep_copy_locals=True) + for _ in range(len(inputs)) + ] + ) + return await self.bound.abatch( + inputs, + [{**self.config, **(conf or {})} for conf in configs], + **{**self.kwargs, **kwargs}, + ) def stream( self, @@ -1531,7 +1584,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: - yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs}) + yield from self.bound.stream( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) async def astream( self, @@ -1540,7 +1595,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> AsyncIterator[Output]: async for item in self.bound.astream( - input, config, **{**self.kwargs, **kwargs} + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} ): yield item @@ -1550,7 +1605,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Iterator[Output]: - yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs}) + yield from self.bound.transform( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) async def atransform( self, @@ -1559,11 +1616,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Any, ) -> AsyncIterator[Output]: async for item in self.bound.atransform( - input, config, **{**self.kwargs, **kwargs} + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} ): yield item +RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig) + + def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index b97d904414d..10cf15c0d88 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,7 +3,8 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional +from typing_extensions import TypedDict if TYPE_CHECKING: from langchain.callbacks.base import BaseCallbackManager, Callbacks @@ -48,7 +49,7 @@ class RunnableConfig(TypedDict, total=False): """ -def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: +def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: empty = RunnableConfig( tags=[], metadata={}, diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index c48d4edbd41..fcb621fe8cf 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2081,7 +2081,8 @@ "stop": [ "Thought:" ] - } + }, + "config": {} } }, "llm": { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index f2447533107..32f488e10c9 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -112,6 +112,104 @@ class FakeRetriever(BaseRetriever): return [Document(page_content="foo"), Document(page_content="bar")] +@pytest.mark.asyncio +async def test_with_config(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") + + assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5 + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"])), + ] + spy.reset_mock() + + assert [ + *fake.with_config(tags=["a-tag"]).stream( + "hello", dict(metadata={"key": "value"}) + ) + ] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})), + ] + spy.reset_mock() + + assert fake.with_config(recursion_limit=5).batch( + ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] + ) == [5, 7] + + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + if i == 0: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {} + else: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == [] + assert call.args[1].get("metadata") == {"key": "value"} + + spy.reset_mock() + + assert fake.with_config(metadata={"a": "b"}).batch( + ["hello", "wooorld"], dict(tags=["a-tag"]) + ) == [5, 7] + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {"a": "b"} + spy.reset_mock() + + assert ( + await fake.with_config(metadata={"a": "b"}).ainvoke( + "hello", config={"callbacks": []} + ) + == 5 + ) + assert spy.call_args_list == [ + mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})), + ] + spy.reset_mock() + + assert [ + part async for part in fake.with_config(metadata={"a": "b"}).astream("hello") + ] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(metadata={"a": "b"})), + ] + spy.reset_mock() + + assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch( + ["hello", "wooorld"], dict(metadata={"key": "value"}) + ) == [ + 5, + 7, + ] + assert spy.call_args_list == [ + mocker.call( + "hello", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + _locals={}, + recursion_limit=5, + ), + ), + mocker.call( + "wooorld", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + _locals={}, + recursion_limit=5, + ), + ), + ] + + @pytest.mark.asyncio async def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() @@ -1125,6 +1223,14 @@ async def test_map_astream_iterator_input() -> None: assert final_value.get("passthrough") == llm_res +def test_with_config_with_config() -> None: + llm = FakeListLLM(responses=["i'm a textbot"]) + + assert dumpd( + llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"]) + ) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]})) + + def test_bind_bind() -> None: llm = FakeListLLM(responses=["i'm a textbot"]) From f69155b4f7a6a4c0181db90204aac18982ae9b69 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 12:50:37 +0200 Subject: [PATCH 02/12] Add run_id, run_name to RunnableConfig --- libs/langchain/langchain/chains/base.py | 13 ++ libs/langchain/langchain/llms/base.py | 5 +- libs/langchain/langchain/schema/retriever.py | 5 + .../langchain/schema/runnable/base.py | 158 +++++++++++------- .../langchain/schema/runnable/config.py | 36 +++- .../langchain/schema/runnable/router.py | 6 +- libs/langchain/langchain/tools/base.py | 5 + 7 files changed, 162 insertions(+), 66 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 701da6c375b..348ce9527b9 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Union +from uuid import UUID import yaml @@ -68,6 +69,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_id=config.get("run_id"), + run_name=config.get("run_name"), **kwargs, ) @@ -89,6 +92,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_id=config.get("run_id"), + run_name=config.get("run_name"), **kwargs, ) @@ -235,6 +240,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_id: Optional[UUID] = None, + run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Execute the chain. @@ -276,6 +283,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = callback_manager.on_chain_start( dumpd(self), inputs, + run_id=run_id, + name=run_name, ) try: outputs = ( @@ -302,6 +311,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_id: Optional[UUID] = None, + run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Asynchronously execute the chain. @@ -343,6 +354,8 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = await callback_manager.on_chain_start( dumpd(self), inputs, + run_id=run_id, + name=run_name, ) try: outputs = ( diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index a91ecd9f2ac..6f7dcc2008b 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string from langchain.schema.output import GenerationChunk from langchain.schema.runnable import RunnableConfig +from langchain.schema.runnable.config import get_config_list logger = logging.getLogger(__name__) @@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): max_concurrency: Optional[int] = None, **kwargs: Any, ) -> List[str]: - config = self._get_config_list(config, len(inputs)) + config = get_config_list(config, len(inputs)) if max_concurrency is None: llm_result = self.generate_prompt( @@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): None, self.batch, inputs, config, max_concurrency ) - config = self._get_config_list(config, len(inputs)) + config = get_config_list(config, len(inputs)) if max_concurrency is None: llm_result = await self.agenerate_prompt( diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 5da50e1497e..ba522172fb5 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -4,6 +4,7 @@ import warnings from abc import ABC, abstractmethod from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional +from uuid import UUID from langchain.load.dump import dumpd from langchain.load.serializable import Serializable @@ -164,6 +165,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[Document]: """Retrieve documents relevant to a query. @@ -193,6 +195,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): run_manager = callback_manager.on_retriever_start( dumpd(self), query, + run_id=run_id, **kwargs, ) try: @@ -220,6 +223,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[Document]: """Asynchronously get documents relevant to a query. @@ -249,6 +253,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): run_manager = await callback_manager.on_retriever_start( dumpd(self), query, + run_id=run_id, **kwargs, ) try: diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 98f7d758635..96878c6969c 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -42,6 +42,7 @@ from langchain.schema.runnable.config import ( ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, + get_config_list, get_executor_for_config, patch_config, ) @@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC): Default implementation of batch, which calls invoke N times. Subclasses should override this method if they can batch more efficiently. """ - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) # If there's only one input, don't bother with the executor if len(inputs) == 1: @@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC): Default implementation of abatch, which calls ainvoke N times. Subclasses should override this method if they can batch more efficiently. """ - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) coros = map(partial(self.ainvoke, **kwargs), inputs, configs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) @@ -246,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC): """ --- Helper methods for Subclasses --- """ - def _get_config_list( - self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int - ) -> List[RunnableConfig]: - """ - Helper method to get a list of configs from a single config or a list of - configs, useful for subclasses overriding batch() or abatch(). - """ - if length < 1: - raise ValueError(f"length must be >= 1, but got {length}") - if isinstance(config, list) and len(config) != length: - raise ValueError( - f"config must be a list of the same length as inputs, " - f"but got {len(config)} configs for {length} inputs" - ) - - return ( - list(map(ensure_config, config)) - if isinstance(config, list) - else [patch_config(config, deep_copy_locals=True) for _ in range(length)] - ) - def _call_with_config( self, func: Union[ @@ -286,6 +266,8 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), input, run_type=run_type, + run_id=config.get("run_id"), + name=config.get("run_name"), ) try: if accepts_run_manager_and_config(func): @@ -327,6 +309,8 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), input, run_type=run_type, + run_id=config.get("run_id"), + name=config.get("run_name"), ) try: if accepts_run_manager_and_config(func): @@ -384,6 +368,8 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), {"input": ""}, run_type=run_type, + run_id=config.get("run_id"), + name=config.get("run_name"), ) try: if accepts_run_manager_and_config(transformer): @@ -464,6 +450,8 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), {"input": ""}, run_type=run_type, + run_id=config.get("run_id"), + name=config.get("run_name"), ) try: # mypy can't quite work out thew type guard here, but this is safe, @@ -539,7 +527,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) first_error = None for runnable in self.runnables: try: @@ -571,7 +561,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) first_error = None for runnable in self.runnables: @@ -603,7 +595,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): from langchain.callbacks.manager import CallbackManager # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -619,9 +611,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): # start the root runs, one per input run_managers = [ cm.on_chain_start( - dumpd(self), input if isinstance(input, dict) else {"input": input} + dumpd(self), + input if isinstance(input, dict) else {"input": input}, + run_id=config.get("run_id"), + name=config.get("run_name"), ) - for cm, input in zip(callback_managers, inputs) + for cm, input, config in zip(callback_managers, inputs, configs) ] first_error = None @@ -661,7 +656,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): from langchain.callbacks.manager import AsyncCallbackManager # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -677,8 +672,13 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): # start the root runs, one per input run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( - cm.on_chain_start(dumpd(self), input) - for cm, input in zip(callback_managers, inputs) + cm.on_chain_start( + dumpd(self), + input, + run_id=config.get("run_id"), + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) ) ) @@ -783,7 +783,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) # invoke all steps in sequence try: @@ -811,7 +813,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) # invoke all steps in sequence try: @@ -838,7 +842,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): from langchain.callbacks.manager import CallbackManager # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -853,8 +857,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ] # start the root runs, one per input run_managers = [ - cm.on_chain_start(dumpd(self), input) - for cm, input in zip(callback_managers, inputs) + cm.on_chain_start( + dumpd(self), + input, + run_id=config.get("run_id"), + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) ] # invoke @@ -889,7 +898,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -905,8 +914,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # start the root runs, one per input run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( - cm.on_chain_start(dumpd(self), input) - for cm, input in zip(callback_managers, inputs) + cm.on_chain_start( + dumpd(self), + input, + run_id=config.get("run_id"), + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) ) ) @@ -942,7 +956,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) steps = [self.first] + self.middle + [self.last] streaming_start_index = 0 @@ -1009,7 +1025,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) steps = [self.first] + self.middle + [self.last] streaming_start_index = len(steps) - 1 @@ -1140,7 +1158,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): local_metadata=None, ) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) # gather results from all steps try: @@ -1179,7 +1199,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + ) # gather results from all steps try: @@ -1529,7 +1551,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> Output: return self.bound.invoke( - input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + input, + cast(RunnableConfig, {**self.config, **(config or {})}), + **{**self.kwargs, **kwargs}, ) async def ainvoke( @@ -1539,7 +1563,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> Output: return await self.bound.ainvoke( - input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + input, + cast(RunnableConfig, {**self.config, **(config or {})}), + **{**self.kwargs, **kwargs}, ) def batch( @@ -1548,13 +1574,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - configs = ( + configs = cast( + List[RunnableConfig], [{**self.config, **(conf or {})} for conf in config] if isinstance(config, list) else [ - patch_config({**self.config, **(config or {})}, deep_copy_locals=True) + patch_config( + cast(RunnableConfig, {**self.config, **(config or {})}), + deep_copy_locals=True, + ) for _ in range(len(inputs)) - ] + ], ) return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) @@ -1564,19 +1594,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - configs = ( + configs = cast( + List[RunnableConfig], [{**self.config, **(conf or {})} for conf in config] if isinstance(config, list) else [ - patch_config({**self.config, **(config or {})}, deep_copy_locals=True) + patch_config( + cast(RunnableConfig, {**self.config, **(config or {})}), + deep_copy_locals=True, + ) for _ in range(len(inputs)) - ] - ) - return await self.bound.abatch( - inputs, - [{**self.config, **(conf or {})} for conf in configs], - **{**self.kwargs, **kwargs}, + ], ) + return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs}) def stream( self, @@ -1585,7 +1615,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> Iterator[Output]: yield from self.bound.stream( - input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + input, + cast(RunnableConfig, {**self.config, **(config or {})}), + **{**self.kwargs, **kwargs}, ) async def astream( @@ -1595,7 +1627,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> AsyncIterator[Output]: async for item in self.bound.astream( - input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + input, + cast(RunnableConfig, {**self.config, **(config or {})}), + **{**self.kwargs, **kwargs}, ): yield item @@ -1606,7 +1640,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Any, ) -> Iterator[Output]: yield from self.bound.transform( - input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + input, + cast(RunnableConfig, {**self.config, **(config or {})}), + **{**self.kwargs, **kwargs}, ) async def atransform( @@ -1616,7 +1652,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Any, ) -> AsyncIterator[Output]: async for item in self.bound.atransform( - input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + input, + cast(RunnableConfig, {**self.config, **(config or {})}), + **{**self.kwargs, **kwargs}, ): yield item diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 10cf15c0d88..c7befd5414b 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,7 +3,9 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union +from uuid import UUID + from typing_extensions import TypedDict if TYPE_CHECKING: @@ -32,6 +34,16 @@ class RunnableConfig(TypedDict, total=False): Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. """ + run_name: str + """ + Name for the tracer run for this call. Defaults to the name of the class. + """ + + run_id: UUID + """ + Unique ID for the tracer run for this call. Defaults to uuid4(). + """ + _locals: Dict[str, Any] """ Local variables @@ -62,6 +74,28 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: return empty +def get_config_list( + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int +) -> List[RunnableConfig]: + """ + Helper method to get a list of configs from a single config or a list of + configs, useful for subclasses overriding batch() or abatch(). + """ + if length < 1: + raise ValueError(f"length must be >= 1, but got {length}") + if isinstance(config, list) and len(config) != length: + raise ValueError( + f"config must be a list of the same length as inputs, " + f"but got {len(config)} configs for {length} inputs" + ) + + return ( + list(map(ensure_config, config)) + if isinstance(config, list) + else [patch_config(config, deep_copy_locals=True) for _ in range(length)] + ) + + def patch_config( config: Optional[RunnableConfig], *, diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 68989bfa7d5..52779325435 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -23,7 +23,7 @@ from langchain.schema.runnable.base import ( RunnableSequence, coerce_to_runnable, ) -from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.config import RunnableConfig, get_config_list from langchain.schema.runnable.utils import gather_with_concurrency @@ -131,7 +131,7 @@ class RouterRunnable( raise ValueError("One or more keys do not have a corresponding runnable") runnables = [self.runnables[key] for key in keys] - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) with ThreadPoolExecutor(max_workers=max_concurrency) as executor: return list( executor.map( @@ -156,7 +156,7 @@ class RouterRunnable( raise ValueError("One or more keys do not have a corresponding runnable") runnables = [self.runnables[key] for key in keys] - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) return await gather_with_concurrency( max_concurrency, *( diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 9ad81033d5d..eaedc5f8d05 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -8,6 +8,7 @@ from abc import abstractmethod from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from uuid import UUID from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -297,6 +298,7 @@ class ChildTool(BaseTool): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run the tool.""" @@ -320,6 +322,7 @@ class ChildTool(BaseTool): {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, + run_id=run_id, **kwargs, ) try: @@ -370,6 +373,7 @@ class ChildTool(BaseTool): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -392,6 +396,7 @@ class ChildTool(BaseTool): {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, + run_id=run_id, **kwargs, ) try: From f95bd0bcd9b3581fe5e128ff10c7bccbee8bbad3 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 16:21:31 +0200 Subject: [PATCH 03/12] Fix issue --- libs/langchain/langchain/schema/runnable/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index c7befd5414b..f12f1f83f3f 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -107,7 +107,11 @@ def patch_config( if deep_copy_locals: config["_locals"] = deepcopy(config["_locals"]) if callbacks is not None: + # If we're replacing callbacks we need to unset run_name and run_id + # As those should apply only to the same run as the original callbacks config["callbacks"] = callbacks + config["run_name"] = None + config["run_id"] = None if recursion_limit is not None: config["recursion_limit"] = recursion_limit return config From 542671231172ca66ead99506d1cdd7a952a5cfb4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 16:49:14 +0200 Subject: [PATCH 04/12] Adjust merge logic --- .../langchain/schema/runnable/base.py | 33 ++++++++++--------- .../schema/runnable/test_runnable.py | 25 ++++++++++++-- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 96878c6969c..abd7ae81c60 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1527,6 +1527,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def lc_namespace(self) -> List[str]: return self.__class__.__module__.split(".")[:-1] + def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig: + copy = cast(RunnableConfig, dict(self.config)) + if config: + for key in config: + copy[key] = config[key] or copy.get(key) + return copy + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} @@ -1552,7 +1559,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Output: return self.bound.invoke( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1564,7 +1571,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Output: return await self.bound.ainvoke( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1576,13 +1583,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> List[Output]: configs = cast( List[RunnableConfig], - [{**self.config, **(conf or {})} for conf in config] + [self._merge_config(conf) for conf in config] if isinstance(config, list) else [ - patch_config( - cast(RunnableConfig, {**self.config, **(config or {})}), - deep_copy_locals=True, - ) + patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ], ) @@ -1596,13 +1600,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> List[Output]: configs = cast( List[RunnableConfig], - [{**self.config, **(conf or {})} for conf in config] + [self._merge_config(conf) for conf in config] if isinstance(config, list) else [ - patch_config( - cast(RunnableConfig, {**self.config, **(config or {})}), - deep_copy_locals=True, - ) + patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ], ) @@ -1616,7 +1617,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.stream( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1628,7 +1629,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.astream( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ): yield item @@ -1641,7 +1642,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.transform( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1653,7 +1654,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.atransform( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ): yield item diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 32f488e10c9..6dde96529ca 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,6 +1,7 @@ from operator import itemgetter from typing import Any, Dict, List, Optional, Union from uuid import UUID +from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler import pytest from freezegun import freeze_time @@ -123,6 +124,25 @@ async def test_with_config(mocker: MockerFixture) -> None: ] spy.reset_mock() + fake_1 = RunnablePassthrough() + fake_2 = RunnablePassthrough() + spy_seq_step = mocker.spy(fake_1.__class__, "invoke") + + sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config( + tags=["b-tag"], max_concurrency=5 + ) + assert sequence.invoke("hello") == "hello" + assert len(spy_seq_step.call_args_list) == 2 + for i, call in enumerate(spy_seq_step.call_args_list): + assert call.args[1] == "hello" + if i == 0: + assert call.args[2].get("tags") == ["a-tag"] + assert call.args[2].get("max_concurrency") is None + else: + assert call.args[2].get("tags") == ["b-tag"] + assert call.args[2].get("max_concurrency") == 5 + spy_seq_step.reset_mock() + assert [ *fake.with_config(tags=["a-tag"]).stream( "hello", dict(metadata={"key": "value"}) @@ -161,14 +181,15 @@ async def test_with_config(mocker: MockerFixture) -> None: assert call.args[1].get("metadata") == {"a": "b"} spy.reset_mock() + handler = ConsoleCallbackHandler() assert ( await fake.with_config(metadata={"a": "b"}).ainvoke( - "hello", config={"callbacks": []} + "hello", config={"callbacks": [handler]} ) == 5 ) assert spy.call_args_list == [ - mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})), + mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})), ] spy.reset_mock() From 9a070320550db26a5145ff82b7d842e8e04466aa Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 16:54:20 +0200 Subject: [PATCH 05/12] Lint --- .../langchain/tests/unit_tests/schema/runnable/test_runnable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 6dde96529ca..849031f875e 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,7 +1,6 @@ from operator import itemgetter from typing import Any, Dict, List, Optional, Union from uuid import UUID -from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler import pytest from freezegun import freeze_time @@ -12,6 +11,7 @@ from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run +from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.chat_models.fake import FakeListChatModel from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM from langchain.load.dump import dumpd, dumps From 738d93215ddf8fe03cec6afd96b50cf6c99e635f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 17:20:49 +0200 Subject: [PATCH 06/12] Allow patching run_name and max_concurrency --- libs/langchain/langchain/schema/runnable/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index f12f1f83f3f..e2274c134e0 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -102,6 +102,8 @@ def patch_config( deep_copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, recursion_limit: Optional[int] = None, + max_concurrency: Optional[int] = None, + run_name: Optional[str] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -114,6 +116,10 @@ def patch_config( config["run_id"] = None if recursion_limit is not None: config["recursion_limit"] = recursion_limit + if max_concurrency is not None: + config["max_concurrency"] = max_concurrency + if run_name is not None: + config["run_name"] = run_name return config From 06e89c1caa5249dbe4025e904bc6c813ec2d6d23 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 18:22:43 +0200 Subject: [PATCH 07/12] Lint --- libs/langchain/langchain/schema/runnable/base.py | 4 +++- libs/langchain/langchain/schema/runnable/config.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index abd7ae81c60..1b8734ac717 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1531,7 +1531,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): copy = cast(RunnableConfig, dict(self.config)) if config: for key in config: - copy[key] = config[key] or copy.get(key) + # Even though the keys aren't literals this is correct + # because both dicts are same type + copy[key] = config[key] or copy.get(key) # type: ignore return copy def bind(self, **kwargs: Any) -> Runnable[Input, Output]: diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index e2274c134e0..d31c1f67bda 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -112,8 +112,10 @@ def patch_config( # If we're replacing callbacks we need to unset run_name and run_id # As those should apply only to the same run as the original callbacks config["callbacks"] = callbacks - config["run_name"] = None - config["run_id"] = None + if "run_name" in config: + del config["run_name"] + if "run_id" in config: + del config["run_id"] if recursion_limit is not None: config["recursion_limit"] = recursion_limit if max_concurrency is not None: From f9a845b382bb6e783ef86063154901d9823eaf8d Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 18:25:42 +0200 Subject: [PATCH 08/12] Lint --- .../tests/unit_tests/schema/runnable/test_runnable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 849031f875e..412fa8e1e7f 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -124,8 +124,8 @@ async def test_with_config(mocker: MockerFixture) -> None: ] spy.reset_mock() - fake_1 = RunnablePassthrough() - fake_2 = RunnablePassthrough() + fake_1: Runnable = RunnablePassthrough() + fake_2: Runnable = RunnablePassthrough() spy_seq_step = mocker.spy(fake_1.__class__, "invoke") sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config( From 4d7cd6db5fcad2c1280b19acd38dba605279c696 Mon Sep 17 00:00:00 2001 From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Date: Fri, 25 Aug 2023 13:24:44 -0700 Subject: [PATCH 09/12] add cm --- libs/langchain/langchain/chains/base.py | 7 ------ libs/langchain/langchain/schema/retriever.py | 5 ---- .../langchain/schema/runnable/base.py | 24 +++++++------------ .../langchain/schema/runnable/config.py | 6 ----- libs/langchain/langchain/tools/base.py | 5 ---- 5 files changed, 8 insertions(+), 39 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 348ce9527b9..848d0940a6e 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -8,7 +8,6 @@ from abc import ABC, abstractmethod from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Union -from uuid import UUID import yaml @@ -69,7 +68,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), - run_id=config.get("run_id"), run_name=config.get("run_name"), **kwargs, ) @@ -92,7 +90,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), - run_id=config.get("run_id"), run_name=config.get("run_name"), **kwargs, ) @@ -240,7 +237,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - run_id: Optional[UUID] = None, run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: @@ -283,7 +279,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = callback_manager.on_chain_start( dumpd(self), inputs, - run_id=run_id, name=run_name, ) try: @@ -311,7 +306,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - run_id: Optional[UUID] = None, run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: @@ -354,7 +348,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = await callback_manager.on_chain_start( dumpd(self), inputs, - run_id=run_id, name=run_name, ) try: diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index ba522172fb5..5da50e1497e 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -4,7 +4,6 @@ import warnings from abc import ABC, abstractmethod from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional -from uuid import UUID from langchain.load.dump import dumpd from langchain.load.serializable import Serializable @@ -165,7 +164,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[Document]: """Retrieve documents relevant to a query. @@ -195,7 +193,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): run_manager = callback_manager.on_retriever_start( dumpd(self), query, - run_id=run_id, **kwargs, ) try: @@ -223,7 +220,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - run_id: Optional[UUID] = None, **kwargs: Any, ) -> List[Document]: """Asynchronously get documents relevant to a query. @@ -253,7 +249,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): run_manager = await callback_manager.on_retriever_start( dumpd(self), query, - run_id=run_id, **kwargs, ) try: diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 1b8734ac717..d3e8d9e85d8 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -266,7 +266,6 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), input, run_type=run_type, - run_id=config.get("run_id"), name=config.get("run_name"), ) try: @@ -309,7 +308,6 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), input, run_type=run_type, - run_id=config.get("run_id"), name=config.get("run_name"), ) try: @@ -368,7 +366,6 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), {"input": ""}, run_type=run_type, - run_id=config.get("run_id"), name=config.get("run_name"), ) try: @@ -450,7 +447,6 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), {"input": ""}, run_type=run_type, - run_id=config.get("run_id"), name=config.get("run_name"), ) try: @@ -528,7 +524,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) first_error = None for runnable in self.runnables: @@ -562,7 +558,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) first_error = None @@ -613,7 +609,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): cm.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, - run_id=config.get("run_id"), name=config.get("run_name"), ) for cm, input, config in zip(callback_managers, inputs, configs) @@ -675,7 +670,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): cm.on_chain_start( dumpd(self), input, - run_id=config.get("run_id"), name=config.get("run_name"), ) for cm, input, config in zip(callback_managers, inputs, configs) @@ -784,7 +778,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) # invoke all steps in sequence @@ -814,7 +808,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) # invoke all steps in sequence @@ -860,7 +854,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): cm.on_chain_start( dumpd(self), input, - run_id=config.get("run_id"), name=config.get("run_name"), ) for cm, input, config in zip(callback_managers, inputs, configs) @@ -917,7 +910,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): cm.on_chain_start( dumpd(self), input, - run_id=config.get("run_id"), name=config.get("run_name"), ) for cm, input, config in zip(callback_managers, inputs, configs) @@ -957,7 +949,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) steps = [self.first] + self.middle + [self.last] @@ -1026,7 +1018,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) steps = [self.first] + self.middle + [self.last] @@ -1159,7 +1151,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): ) # start the root run run_manager = callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) # gather results from all steps @@ -1200,7 +1192,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( - dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name") + dumpd(self), input, name=config.get("run_name") ) # gather results from all steps diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index d31c1f67bda..74ada6cecbf 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -4,7 +4,6 @@ from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union -from uuid import UUID from typing_extensions import TypedDict @@ -39,11 +38,6 @@ class RunnableConfig(TypedDict, total=False): Name for the tracer run for this call. Defaults to the name of the class. """ - run_id: UUID - """ - Unique ID for the tracer run for this call. Defaults to uuid4(). - """ - _locals: Dict[str, Any] """ Local variables diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index eaedc5f8d05..9ad81033d5d 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -8,7 +8,6 @@ from abc import abstractmethod from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union -from uuid import UUID from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -298,7 +297,6 @@ class ChildTool(BaseTool): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run the tool.""" @@ -322,7 +320,6 @@ class ChildTool(BaseTool): {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, - run_id=run_id, **kwargs, ) try: @@ -373,7 +370,6 @@ class ChildTool(BaseTool): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, - run_id: Optional[UUID] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -396,7 +392,6 @@ class ChildTool(BaseTool): {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, - run_id=run_id, **kwargs, ) try: From 897f791940400fed3fa59922e72010cbfda99f23 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 26 Aug 2023 08:01:10 +0200 Subject: [PATCH 10/12] Remove run_id from patch --- libs/langchain/langchain/schema/runnable/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 74ada6cecbf..3f87f044039 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -103,13 +103,11 @@ def patch_config( if deep_copy_locals: config["_locals"] = deepcopy(config["_locals"]) if callbacks is not None: - # If we're replacing callbacks we need to unset run_name and run_id - # As those should apply only to the same run as the original callbacks + # If we're replacing callbacks we need to unset run_name + # As that should apply only to the same run as the original callbacks config["callbacks"] = callbacks if "run_name" in config: del config["run_name"] - if "run_id" in config: - del config["run_id"] if recursion_limit is not None: config["recursion_limit"] = recursion_limit if max_concurrency is not None: From fc42726ea07e79c6da052ffc32c0ed79aaa3a8d1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 1 Sep 2023 15:03:48 +0100 Subject: [PATCH 11/12] Styling --- libs/langchain/langchain/schema/runnable/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index d3e8d9e85d8..c727ba676c1 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1575,15 +1575,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - configs = cast( - List[RunnableConfig], - [self._merge_config(conf) for conf in config] - if isinstance(config, list) - else [ + if isinstance(config, list): + configs = cast( + List[RunnableConfig], [self._merge_config(conf) for conf in config] + ) + else: + configs = [ patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) - ], - ) + ] return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) async def abatch( From 81ebcc161e31007b46d18be35ae3242c95b03a96 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 1 Sep 2023 15:46:53 +0100 Subject: [PATCH 12/12] Lint --- libs/langchain/langchain/schema/runnable/base.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c727ba676c1..88572bfee16 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1592,15 +1592,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - configs = cast( - List[RunnableConfig], - [self._merge_config(conf) for conf in config] - if isinstance(config, list) - else [ + if isinstance(config, list): + configs = cast( + List[RunnableConfig], [self._merge_config(conf) for conf in config] + ) + else: + configs = [ patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) - ], - ) + ] return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs}) def stream(