From 50b13ab9384932b9dd4aef64dd08150dbb6a5655 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 9 Aug 2023 13:26:09 -0700 Subject: [PATCH 01/23] wip --- libs/langchain/langchain/chains/base.py | 8 +- libs/langchain/langchain/schema/runnable.py | 158 ++++++------------ .../smith/evaluation/runner_utils.py | 8 +- 3 files changed, 60 insertions(+), 114 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 301b0143e7b..a490c583158 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -62,7 +62,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: - return self(input, **(config or {}), **kwargs) + _config: Dict[str, Any] = dict(config) if config else {} + _config.pop("_locals", None) + return self(input, **_config, **kwargs) async def ainvoke( self, @@ -76,7 +78,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): None, partial(self.invoke, input, config, **kwargs) ) - return await self.acall(input, **(config or {}), **kwargs) + _config: Dict[str, Any] = dict(config) if config else {} + _config.pop("_locals", None) + return await self.acall(input, **_config, **kwargs) memory: Optional[BaseMemory] = None """Optional memory object. Defaults to None. diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 8edafe4599e..eebd5a96aa7 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy from itertools import tee from typing import ( Any, @@ -66,6 +67,35 @@ class RunnableConfig(TypedDict, total=False): Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. """ + _locals: Dict[str, Any] + """ + Local variables + """ + + +def _empty_config() -> RunnableConfig: + return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + + +def _get_callback_manager(config: Mapping) -> Any: + from langchain.callbacks.manager import CallbackManager + + return CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +def _get_async_callback_manager(config: Mapping) -> Any: + from langchain.callbacks.manager import AsyncCallbackManager + + return AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do @@ -243,7 +273,7 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [config.copy() if config is not None else {} for _ in range(length)] + else [deepcopy(config) if config is not None else {} for _ in range(length)] ) def _call_with_config( @@ -255,14 +285,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - from langchain.callbacks.manager import CallbackManager - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -288,14 +312,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - from langchain.callbacks.manager import AsyncCallbackManager - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -322,8 +340,6 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. Use this to implement `stream()` or `transform()` in Runnable subclasses.""" - from langchain.callbacks.manager import CallbackManager - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = tee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -333,11 +349,7 @@ class Runnable(Generic[Input, Output], ABC): final_output_supported = True config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -393,8 +405,6 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" - from langchain.callbacks.manager import AsyncCallbackManager - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -404,11 +414,7 @@ class Runnable(Generic[Input, Output], ABC): final_output_supported = True config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -473,19 +479,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): yield from self.fallbacks def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - from langchain.callbacks.manager import CallbackManager - # setup callbacks config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -516,19 +512,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -751,19 +737,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - from langchain.callbacks.manager import CallbackManager - # setup callbacks - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = config or _empty_config() + callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -771,11 +747,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke all steps in sequence try: + callbacks = run_manager.get_child() for step in self.steps: input = step.invoke( input, # mark each step as a child run - _patch_config(config, run_manager.get_child()), + _patch_config(config, callbacks), ) # finish the root run except (KeyboardInterrupt, Exception) as e: @@ -790,19 +767,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -946,19 +913,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): def stream( self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: - from langchain.callbacks.manager import CallbackManager - # setup callbacks config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -1023,19 +980,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): async def astream( self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -1173,19 +1120,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": input} @@ -1464,10 +1401,11 @@ class RouterRunnable( def _patch_config( - config: RunnableConfig, callback_manager: BaseCallbackManager + config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None ) -> RunnableConfig: - config = config.copy() + config = deepcopy(config) config["callbacks"] = callback_manager + config["_locals"] = _locals or {} return config diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 5b3d5775c49..be55f6f99a1 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -636,7 +636,9 @@ async def _arun_chain( else: output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) else: - runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) + runnable_config = RunnableConfig( + tags=tags or [], callbacks=callbacks, _locals={} + ) output = await chain.ainvoke(inputs_, config=runnable_config) return output @@ -957,7 +959,9 @@ def _run_chain( else: output = chain(inputs_, callbacks=callbacks, tags=tags) else: - runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) + runnable_config = RunnableConfig( + tags=tags or [], callbacks=callbacks, _locals={} + ) output = chain.invoke(inputs_, config=runnable_config) return output From eb0134fbb3c728fb5c9180384276315f6318497b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 9 Aug 2023 14:13:06 -0700 Subject: [PATCH 02/23] rfc --- libs/langchain/langchain/chat_models/base.py | 8 ++- libs/langchain/langchain/llms/base.py | 15 +++--- libs/langchain/langchain/schema/retriever.py | 8 ++- libs/langchain/langchain/schema/runnable.py | 57 +++++++++++++++++--- 4 files changed, 69 insertions(+), 19 deletions(-) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index b06b99f99d7..3d343274f5a 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -103,12 +103,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): stop: Optional[List[str]] = None, **kwargs: Any, ) -> BaseMessageChunk: + _config: Dict[str, Any] = dict(config or {}) + _config.pop("_locals", None) return cast( BaseMessageChunk, cast( ChatGeneration, self.generate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}), **kwargs + [self._convert_input(input)], stop=stop, **_config, **kwargs ).generations[0][0], ).message, ) @@ -127,8 +129,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): None, partial(self.invoke, input, config, stop=stop, **kwargs) ) + _config: Dict[str, Any] = dict(config or {}) + _config.pop("_locals", None) llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}), **kwargs + [self._convert_input(input)], stop=stop, **_config, **kwargs ) return cast( BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 7da494de78b..04422124743 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -219,13 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC): stop: Optional[List[str]] = None, **kwargs: Any, ) -> str: - return ( - self.generate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}), **kwargs - ) - .generations[0][0] - .text + _config: Dict[str, Any] = dict(config or {}) + _config.pop("_locals", None) + result = self.generate_prompt( + [self._convert_input(input)], stop=stop, **_config, **kwargs ) + return result.generations[0][0].text async def ainvoke( self, @@ -241,8 +240,10 @@ class BaseLLM(BaseLanguageModel[str], ABC): None, partial(self.invoke, input, config, stop=stop, **kwargs) ) + _config: Dict[str, Any] = dict(config or {}) + _config.pop("_locals", None) llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **(config or {}), **kwargs + [self._convert_input(input)], stop=stop, **_config, **kwargs ) return llm_result.generations[0][0].text diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 9df3e7a1389..538ae1ed1c1 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -107,7 +107,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): def invoke( self, input: str, config: Optional[RunnableConfig] = None ) -> List[Document]: - return self.get_relevant_documents(input, **(config or {})) + _config: Dict[str, Any] = dict(config or {}) + _config.pop("_locals", None) + return self.get_relevant_documents(input, **_config) async def ainvoke( self, input: str, config: Optional[RunnableConfig] = None @@ -116,7 +118,9 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): # If the retriever doesn't implement async, use default implementation return await super().ainvoke(input, config) - return await self.aget_relevant_documents(input, **(config or {})) + _config: Dict[str, Any] = dict(config or {}) + _config.pop("_locals", None) + return await self.aget_relevant_documents(input, **_config) @abstractmethod def _get_relevant_documents( diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index eebd5a96aa7..47679c8883c 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -674,6 +674,46 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): raise first_error +class PutLocalVar(Serializable, Runnable[Input, Input]): + key: Union[str, Dict[str, str]] + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if isinstance(self.key, str): + config["_locals"][self.key] = input + else: + if not isinstance(input, Mapping): + raise ValueError + for get_key, put_key in self.key.items(): + config["_locals"][put_key] = input[get_key] + return self._call_with_config(lambda x: x, input, config) + + +class GetLocalVar(Serializable, Runnable[str, Any]): + key: str + passthrough_key: Optional[str] = None + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if self.passthrough_key is not None: + return {self.key: config["_locals"][self.key], self.passthrough_key: input} + return config["_locals"][self.key] + + class RunnableSequence(Serializable, Runnable[Input, Output]): """ A sequence of runnables, where the output of each is the input of the next. @@ -749,11 +789,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): try: callbacks = run_manager.get_child() for step in self.steps: - input = step.invoke( - input, - # mark each step as a child run - _patch_config(config, callbacks), - ) + # mark each step as child run + step_config = _patch_config(config, callbacks) + input = step.invoke(input, step_config) # finish the root run except (KeyboardInterrupt, Exception) as e: run_manager.on_chain_error(e) @@ -1401,11 +1439,14 @@ class RouterRunnable( def _patch_config( - config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None + config: RunnableConfig, + callback_manager: BaseCallbackManager, + _locals: Optional[Dict[str, Any]] = None, ) -> RunnableConfig: - config = deepcopy(config) + config = config.copy() config["callbacks"] = callback_manager - config["_locals"] = _locals or {} + if _locals is not None: + config["_locals"] = _locals return config From 8c1a528c7150a4cc833cedb567b668c3ab17a745 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 13:52:09 -0700 Subject: [PATCH 03/23] cr --- .../langchain/schema/runnable/base.py | 40 ------------------- .../langchain/schema/runnable/passthrough.py | 5 ++- 2 files changed, 3 insertions(+), 42 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 9354355fa22..ee3f7c11427 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -653,46 +653,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): raise first_error -class PutLocalVar(Serializable, Runnable[Input, Input]): - key: Union[str, Dict[str, str]] - - def __init__(self, key: str, **kwargs: Any) -> None: - super().__init__(key=key, **kwargs) - - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: - if config is None: - raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - if isinstance(self.key, str): - config["_locals"][self.key] = input - else: - if not isinstance(input, Mapping): - raise ValueError - for get_key, put_key in self.key.items(): - config["_locals"][put_key] = input[get_key] - return self._call_with_config(lambda x: x, input, config) - - -class GetLocalVar(Serializable, Runnable[str, Any]): - key: str - passthrough_key: Optional[str] = None - - def __init__(self, key: str, **kwargs: Any) -> None: - super().__init__(key=key, **kwargs) - - def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: - if config is None: - raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - if self.passthrough_key is not None: - return {self.key: config["_locals"][self.key], self.passthrough_key: input} - return config["_locals"][self.key] - - class RunnableSequence(Serializable, Runnable[Input, Output]): """ A sequence of runnables, where the output of each is the input of the next. diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index a97e708b64b..41a130aa735 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -41,7 +41,8 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): ) -> Iterator[Input]: return self._transform_stream_with_config(input, identity, config) - def atransform( + async def atransform( self, input: AsyncIterator[Input], config: RunnableConfig | None = None ) -> AsyncIterator[Input]: - return self._atransform_stream_with_config(input, identity, config) + async for chunk in self._atransform_stream_with_config(input, identity, config): + yield chunk From bd80cad6dbd045e36afe4be4071d1ef612ff9ea9 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 13:52:19 -0700 Subject: [PATCH 04/23] add --- .../langchain/schema/runnable/locals.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 libs/langchain/langchain/schema/runnable/locals.py diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py new file mode 100644 index 00000000000..53d8f5a2ce6 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union + +from langchain.load.serializable import Serializable +from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough +from langchain.schema.runnable.base import Input, Output + + +class PutLocalVar(RunnablePassthrough): + key: Union[str, Mapping[str, str]] + """The key(s) to use for storing the input variable(s) in local state. + + If a string is provided then the entire input is stored under that key. If a + Mapping is provided, then the map values are gotten from the input and + stored in local state under the map keys. + """ + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if isinstance(self.key, str): + config["_locals"][self.key] = input + elif isinstance(input, Mapping): + for input_key, put_key in self.key.items(): + config["_locals"][put_key] = input[input_key] + else: + raise TypeError( + f"`key` should be a string or Mapping[str, str], received type " + f"{(type(self.key))}." + ) + + def _concat_put( + self, input: Input, *, config: Optional[RunnableConfig] = None + ) -> None: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if isinstance(self.key, str): + if self.key not in config["_locals"]: + config["_locals"][self.key] = input + else: + config["_locals"][self.key] += input + elif isinstance(input, Mapping): + for input_key, put_key in self.key.items(): + if put_key not in config["_locals"]: + config["_locals"][put_key] = input + else: + config["_locals"][put_key] += input + else: + raise TypeError( + f"`key` should be a string or Mapping[str, str], received type " + f"{(type(self.key))}." + ) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + self._put(input, config=config) + return super().invoke(input, config) + + async def ainvoke( + self, input: Input, config: RunnableConfig | None = None + ) -> Input: + self._put(input, config=config) + return await super().ainvoke(input, config) + + def transform( + self, input: Iterator[Input], config: RunnableConfig | None = None + ) -> Iterator[Input]: + for chunk in super().transform(input, config=config): + self._concat_put(input, config=config) + yield chunk + + async def atransform( + self, input: AsyncIterator[Input], config: RunnableConfig | None = None + ) -> AsyncIterator[Input]: + async for chunk in super().atransform(input, config=config): + self._concat_put(input, config=config) + yield chunk + + +class GetLocalVar( + Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] +): + key: str + """The key to extract from the local state.""" + passthrough_key: Optional[str] = None + """The key to use for passing through the invocation input. + + If None, then only the value retrieved from local state is returned. Otherwise a + dictionary ``{self.key: <>, self.passthrough_key: <>}`` + is returned. + """ + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if self.passthrough_key is not None: + return {self.key: config["_locals"][self.key], self.passthrough_key: input} + return config["_locals"][self.key] From c447e9a854deef90de62bad39991f1ea55a8f29b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 15:29:00 -0700 Subject: [PATCH 05/23] cr --- .../langchain/schema/runnable/base.py | 8 ++-- .../langchain/schema/runnable/locals.py | 40 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index ee3f7c11427..3f3e90ba27f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -243,8 +243,8 @@ class Runnable(Generic[Input, Output], ABC): def _call_with_config( self, - func: Callable[[Input], Output], - input: Input, + func: Callable[[Any], Output], + input: Any, config: Optional[RunnableConfig], run_type: Optional[str] = None, ) -> Output: @@ -273,8 +273,8 @@ class Runnable(Generic[Input, Output], ABC): async def _acall_with_config( self, - func: Callable[[Input], Awaitable[Output]], - input: Input, + func: Callable[[Any], Awaitable[Output]], + input: Any, config: Optional[RunnableConfig], run_type: Optional[str] = None, ) -> Output: diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index 53d8f5a2ce6..cf51336dc9c 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -16,7 +16,7 @@ class PutLocalVar(RunnablePassthrough): stored in local state under the map keys. """ - def __init__(self, key: str, **kwargs: Any) -> None: + def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: super().__init__(key=key, **kwargs) def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None: @@ -63,13 +63,13 @@ class PutLocalVar(RunnablePassthrough): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: self._put(input, config=config) - return super().invoke(input, config) + return super().invoke(input, config=config) async def ainvoke( self, input: Input, config: RunnableConfig | None = None ) -> Input: self._put(input, config=config) - return await super().ainvoke(input, config) + return await super().ainvoke(input, config=config) def transform( self, input: Iterator[Input], config: RunnableConfig | None = None @@ -102,14 +102,40 @@ class GetLocalVar( def __init__(self, key: str, **kwargs: Any) -> None: super().__init__(key=key, **kwargs) + def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]: + if self.passthrough_key: + return { + self.key: full_input["locals"][self.key], + self.passthrough_key: full_input["input"], + } + else: + return full_input["locals"][self.key] + + async def _aget( + self, full_input: Dict + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + return self._get(full_input) + def invoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Union[Output, Dict[str, Union[Input, Output]]]: if config is None: raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " + "GetLocalVar should only be used in a RunnableSequence, and should " "therefore always receive a non-null config." ) - if self.passthrough_key is not None: - return {self.key: config["_locals"][self.key], self.passthrough_key: input} - return config["_locals"][self.key] + + log_input = {"input": input, "locals": config["_locals"]} + return self._call_with_config(self._get, log_input, config) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if config is None: + raise ValueError( + "GetLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + + log_input = {"input": input, "locals": config["_locals"]} + return await self._acall_with_config(self._aget, log_input, config) From 6b0a849f5953b05eab530cbceffe5ab6b44c3a72 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 16:22:12 -0700 Subject: [PATCH 06/23] fix --- .../langchain/schema/runnable/__init__.py | 3 ++ .../langchain/schema/runnable/base.py | 11 +++++--- .../langchain/schema/runnable/locals.py | 28 +++++++++++++------ 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 0dbabd1579d..bae6aebb024 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -7,10 +7,13 @@ from langchain.schema.runnable.base import ( RunnableWithFallbacks, ) from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable __all__ = [ + "GetLocalVar", + "PutLocalVar", "RouterInput", "RouterRunnable", "Runnable", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 3f3e90ba27f..704a518cde8 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -238,7 +238,10 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [deepcopy(config) if config is not None else {} for _ in range(length)] + else [ + deepcopy(config) if config is not None else _empty_config() + for _ in range(length) + ] ) def _call_with_config( @@ -750,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -896,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -963,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index cf51336dc9c..65e63507bc4 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union from langchain.load.serializable import Serializable -from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough -from langchain.schema.runnable.base import Input, Output +from langchain.schema.runnable.base import Input, Output, Runnable +from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.passthrough import RunnablePassthrough class PutLocalVar(RunnablePassthrough): @@ -27,7 +28,12 @@ class PutLocalVar(RunnablePassthrough): ) if isinstance(self.key, str): config["_locals"][self.key] = input - elif isinstance(input, Mapping): + elif isinstance(self.key, Mapping): + if not isinstance(input, Mapping): + raise TypeError( + f"Received key of type Mapping but input of type {type(input)}. " + f"input is expected to be of type Mapping when key is Mapping." + ) for input_key, put_key in self.key.items(): config["_locals"][put_key] = input[input_key] else: @@ -44,17 +50,23 @@ class PutLocalVar(RunnablePassthrough): "PutLocalVar should only be used in a RunnableSequence, and should " "therefore always receive a non-null config." ) + print(config) if isinstance(self.key, str): if self.key not in config["_locals"]: config["_locals"][self.key] = input else: config["_locals"][self.key] += input - elif isinstance(input, Mapping): + elif isinstance(self.key, Mapping): + if not isinstance(input, Mapping): + raise TypeError( + f"Received key of type Mapping but input of type {type(input)}. " + f"input is expected to be of type Mapping when key is Mapping." + ) for input_key, put_key in self.key.items(): if put_key not in config["_locals"]: - config["_locals"][put_key] = input + config["_locals"][put_key] = input[input_key] else: - config["_locals"][put_key] += input + config["_locals"][put_key] += input[input_key] else: raise TypeError( f"`key` should be a string or Mapping[str, str], received type " @@ -75,14 +87,14 @@ class PutLocalVar(RunnablePassthrough): self, input: Iterator[Input], config: RunnableConfig | None = None ) -> Iterator[Input]: for chunk in super().transform(input, config=config): - self._concat_put(input, config=config) + self._concat_put(chunk, config=config) yield chunk async def atransform( self, input: AsyncIterator[Input], config: RunnableConfig | None = None ) -> AsyncIterator[Input]: async for chunk in super().atransform(input, config=config): - self._concat_put(input, config=config) + self._concat_put(chunk, config=config) yield chunk From 9e906c39ba974ae33d596174873a173b505648e9 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 16:22:22 -0700 Subject: [PATCH 07/23] nit --- libs/langchain/langchain/schema/runnable/locals.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index 65e63507bc4..5061dbf38c1 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -50,7 +50,6 @@ class PutLocalVar(RunnablePassthrough): "PutLocalVar should only be used in a RunnableSequence, and should " "therefore always receive a non-null config." ) - print(config) if isinstance(self.key, str): if self.key not in config["_locals"]: config["_locals"][self.key] = input From 6f69b19ff583a37387b8403f15fae7bfbcede4ba Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 16:45:52 -0700 Subject: [PATCH 08/23] wip tests --- .../unit_tests/schema/runnable/__init__.py | 0 .../unit_tests/schema/runnable/test_locals.py | 31 +++++++++++++++++++ .../schema/{ => runnable}/test_runnable.py | 0 3 files changed, 31 insertions(+) create mode 100644 libs/langchain/tests/unit_tests/schema/runnable/__init__.py create mode 100644 libs/langchain/tests/unit_tests/schema/runnable/test_locals.py rename libs/langchain/tests/unit_tests/schema/{ => runnable}/test_runnable.py (100%) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__init__.py b/libs/langchain/tests/unit_tests/schema/runnable/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py new file mode 100644 index 00000000000..d0a3fb38d9d --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -0,0 +1,31 @@ +import pytest + +from langchain.schema.runnable import GetLocalVar, PutLocalVar + + +@pytest.mark.asyncio +async def test_put_get() -> None: + runnable = PutLocalVar("input") | GetLocalVar("input") + assert runnable.invoke("foo") == "foo" + assert runnable.batch(["foo", "bar"]) == ["foo", "bar"] + assert list(runnable.stream("foo"))[0] == "foo" + + assert await runnable.ainvoke("foo") == "foo" + assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"] + async for x in runnable.astream("foo"): + assert x == "foo" + + +def test_missing_config() -> None: + with pytest.raises(ValueError): + PutLocalVar("input").invoke("foo") + + with pytest.raises(ValueError): + GetLocalVar("input").invoke("foo") + + +def test_get_missing_var_invoke() -> None: + runnable = PutLocalVar("input") | GetLocalVar("missing") + with pytest.raises(KeyError): + runnable.invoke("foo") + diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py similarity index 100% rename from libs/langchain/tests/unit_tests/schema/test_runnable.py rename to libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py From ab21af71be3c5a2fbe548061228df525c635ba86 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 17:28:02 -0700 Subject: [PATCH 09/23] wip --- .../langchain/schema/runnable/base.py | 18 +++++----- .../unit_tests/schema/runnable/test_locals.py | 36 ++++++++++++++++++- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 704a518cde8..c91456394e2 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1091,7 +1091,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): step.invoke, input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(deepcopy(config), run_manager.get_child()), ) for step in steps.values() ] @@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index d0a3fb38d9d..dce548fc69f 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -1,6 +1,13 @@ import pytest -from langchain.schema.runnable import GetLocalVar, PutLocalVar +from langchain import PromptTemplate +from langchain.llms import FakeListLLM +from langchain.schema.runnable import ( + GetLocalVar, + PutLocalVar, + RunnablePassthrough, + RunnableSequence, +) @pytest.mark.asyncio @@ -29,3 +36,30 @@ def test_get_missing_var_invoke() -> None: with pytest.raises(KeyError): runnable.invoke("foo") + +def test_get_in_map() -> None: + runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")} + assert runnable.invoke("foo") == {"bar": "foo"} + + +def test_cant_put_in_map() -> None: + runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") + with pytest.raises(KeyError): + runnable.invoke("foo") + + +def test_get_passthrough_key() -> None: + runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output") + assert runnable.invoke("foo") == {"input": "foo", "output": "foo"} + + +def test_multi_step_sequence() -> None: + prompt = PromptTemplate.from_template("say {foo}") + runnable = ( + PutLocalVar("foo") + | {"foo": RunnablePassthrough()} + | prompt + | FakeListLLM(responses=["bar"]) + | GetLocalVar("foo", passthrough_key="output") + ) + assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"} From 7fe474d19820e5ffd65d30f03446301a644d8a7c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:02:11 +0100 Subject: [PATCH 10/23] Update snapshots --- .../schema/{ => runnable}/__snapshots__/test_runnable.ambr | 1 + 1 file changed, 1 insertion(+) rename libs/langchain/tests/unit_tests/schema/{ => runnable}/__snapshots__/test_runnable.ambr (99%) diff --git a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr similarity index 99% rename from libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr rename to libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 2d2872c1479..4a59ae63088 100644 --- a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -1331,6 +1331,7 @@ "lc": 1, "type": "not_implemented", "id": [ + "runnable", "test_runnable", "FakeRetriever" ] From c1b1666ec850e465bf93bd01d34e09fc457076cc Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:02:29 +0100 Subject: [PATCH 11/23] Ensure config defaults apply even when a config is passed in --- .../langchain/schema/runnable/base.py | 36 +++++++++---------- .../schema/runnable/test_runnable.py | 18 +++++++--- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c91456394e2..0d9df2baeea 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -35,8 +35,11 @@ from langchain.schema.runnable.utils import ( from langchain.utils.aiter import atee, py_anext -def _empty_config() -> RunnableConfig: - return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) +def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: + empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + if config is not None: + empty.update(config) + return empty def _get_callback_manager(config: Mapping) -> Any: @@ -238,10 +241,7 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [ - deepcopy(config) if config is not None else _empty_config() - for _ in range(length) - ] + else [deepcopy(_ensure_config(config)) for _ in range(length)] ) def _call_with_config( @@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -724,7 +724,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -753,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -899,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -966,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index c0cae4d9bdf..8bfecb1821d 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -134,8 +134,12 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert spy.call_args_list == [ - mocker.call("hello", dict(tags=["a-tag"])), - mocker.call("wooorld", dict(tags=["a-tag"])), + mocker.call( + "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), + mocker.call( + "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), ] spy.reset_mock() @@ -156,8 +160,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: 7, ] assert spy.call_args_list == [ - mocker.call("hello", dict(metadata={"key": "value"})), - mocker.call("wooorld", dict(metadata={"key": "value"})), + mocker.call( + "hello", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), + mocker.call( + "wooorld", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), ] From a5e7dcec61cdcaf2c075b5e83117ee3ab14e92c3 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:03:28 +0100 Subject: [PATCH 12/23] Lint --- libs/langchain/tests/unit_tests/schema/runnable/test_locals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index dce548fc69f..8f8755a9644 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -28,7 +28,7 @@ def test_missing_config() -> None: PutLocalVar("input").invoke("foo") with pytest.raises(ValueError): - GetLocalVar("input").invoke("foo") + GetLocalVar[str, str]("input").invoke("foo") def test_get_missing_var_invoke() -> None: From 8ddaaf3d4100ddbc6fc8e7fa9df39d6fc6a67c9a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:10:35 +0100 Subject: [PATCH 13/23] Move config helpers --- .../langchain/schema/runnable/base.py | 82 +++++++------------ .../langchain/schema/runnable/config.py | 28 ++++++- 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 0d9df2baeea..5a1d5b29e4f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -28,40 +28,18 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field -from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.config import ( + RunnableConfig, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, +) from langchain.schema.runnable.utils import ( gather_with_concurrency, ) from langchain.utils.aiter import atee, py_anext -def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: - empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) - if config is not None: - empty.update(config) - return empty - - -def _get_callback_manager(config: Mapping) -> Any: - from langchain.callbacks.manager import CallbackManager - - return CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - -def _get_async_callback_manager(config: Mapping) -> Any: - from langchain.callbacks.manager import AsyncCallbackManager - - return AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do Output = TypeVar("Output") @@ -241,7 +219,7 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [deepcopy(_ensure_config(config)) for _ in range(length)] + else [deepcopy(ensure_config(config)) for _ in range(length)] ) def _call_with_config( @@ -253,8 +231,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -283,8 +261,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -322,8 +300,8 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -387,8 +365,8 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -462,8 +440,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -495,8 +473,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -724,8 +702,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -753,8 +731,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -899,8 +877,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -966,8 +944,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -1068,7 +1046,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = _ensure_config(config) + config = ensure_config(config) callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1108,8 +1086,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": input} diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index f2bf28fcb57..cd620077e1f 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List, Optional, TypedDict from langchain.callbacks.base import Callbacks +from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager class RunnableConfig(TypedDict, total=False): @@ -30,3 +31,28 @@ class RunnableConfig(TypedDict, total=False): """ Local variables """ + + +def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: + empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + if config is not None: + empty.update(config) + return empty + + +def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: + return CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +def get_async_callback_manager_for_config( + config: RunnableConfig, +) -> AsyncCallbackManager: + return AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) From 46f3850794f5fc14477d5545c6d1edd6bbfeca1a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:25:41 +0100 Subject: [PATCH 14/23] Lint --- .../langchain/schema/runnable/base.py | 1 - .../langchain/schema/runnable/config.py | 2 +- .../langchain/schema/runnable/locals.py | 78 +++++++++---------- 3 files changed, 39 insertions(+), 42 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c0caa6d9a20..1ca853174a2 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -52,7 +52,6 @@ from langchain.schema.runnable.utils import ( from langchain.utils.aiter import atee, py_anext from langchain.utils.iter import safetee - Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do Output = TypeVar("Output") diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index cd620077e1f..716fc361161 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, TypedDict from langchain.callbacks.base import Callbacks -from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager +from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager class RunnableConfig(TypedDict, total=False): diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index 5061dbf38c1..6d668059edf 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -2,6 +2,10 @@ from __future__ import annotations from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.load.serializable import Serializable from langchain.schema.runnable.base import Input, Output, Runnable from langchain.schema.runnable.config import RunnableConfig @@ -20,30 +24,12 @@ class PutLocalVar(RunnablePassthrough): def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: super().__init__(key=key, **kwargs) - def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None: - if config is None: - raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - if isinstance(self.key, str): - config["_locals"][self.key] = input - elif isinstance(self.key, Mapping): - if not isinstance(input, Mapping): - raise TypeError( - f"Received key of type Mapping but input of type {type(input)}. " - f"input is expected to be of type Mapping when key is Mapping." - ) - for input_key, put_key in self.key.items(): - config["_locals"][put_key] = input[input_key] - else: - raise TypeError( - f"`key` should be a string or Mapping[str, str], received type " - f"{(type(self.key))}." - ) - def _concat_put( - self, input: Input, *, config: Optional[RunnableConfig] = None + self, + input: Input, + *, + config: Optional[RunnableConfig] = None, + replace: bool = False, ) -> None: if config is None: raise ValueError( @@ -51,7 +37,7 @@ class PutLocalVar(RunnablePassthrough): "therefore always receive a non-null config." ) if isinstance(self.key, str): - if self.key not in config["_locals"]: + if self.key not in config["_locals"] or replace: config["_locals"][self.key] = input else: config["_locals"][self.key] += input @@ -62,7 +48,7 @@ class PutLocalVar(RunnablePassthrough): f"input is expected to be of type Mapping when key is Mapping." ) for input_key, put_key in self.key.items(): - if put_key not in config["_locals"]: + if put_key not in config["_locals"] or replace: config["_locals"][put_key] = input[input_key] else: config["_locals"][put_key] += input[input_key] @@ -73,24 +59,30 @@ class PutLocalVar(RunnablePassthrough): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: - self._put(input, config=config) + self._concat_put(input, config=config, replace=True) return super().invoke(input, config=config) async def ainvoke( - self, input: Input, config: RunnableConfig | None = None + self, input: Input, config: Optional[RunnableConfig] = None ) -> Input: - self._put(input, config=config) + self._concat_put(input, config=config, replace=True) return await super().ainvoke(input, config=config) def transform( - self, input: Iterator[Input], config: RunnableConfig | None = None + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Iterator[Input]: for chunk in super().transform(input, config=config): self._concat_put(chunk, config=config) yield chunk async def atransform( - self, input: AsyncIterator[Input], config: RunnableConfig | None = None + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Input]: async for chunk in super().atransform(input, config=config): self._concat_put(chunk, config=config) @@ -113,19 +105,27 @@ class GetLocalVar( def __init__(self, key: str, **kwargs: Any) -> None: super().__init__(key=key, **kwargs) - def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]: + def _get( + self, + input: Input, + run_manager: Union[CallbackManagerForChainRun, Any], + config: RunnableConfig, + ) -> Union[Output, Dict[str, Union[Input, Output]]]: if self.passthrough_key: return { - self.key: full_input["locals"][self.key], - self.passthrough_key: full_input["input"], + self.key: config["_locals"][self.key], + self.passthrough_key: input, } else: - return full_input["locals"][self.key] + return config["_locals"][self.key] async def _aget( - self, full_input: Dict + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, ) -> Union[Output, Dict[str, Union[Input, Output]]]: - return self._get(full_input) + return self._get(input, run_manager, config) def invoke( self, input: Input, config: Optional[RunnableConfig] = None @@ -136,8 +136,7 @@ class GetLocalVar( "therefore always receive a non-null config." ) - log_input = {"input": input, "locals": config["_locals"]} - return self._call_with_config(self._get, log_input, config) + return self._call_with_config(self._get, input, config) async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None @@ -148,5 +147,4 @@ class GetLocalVar( "therefore always receive a non-null config." ) - log_input = {"input": input, "locals": config["_locals"]} - return await self._acall_with_config(self._aget, log_input, config) + return await self._acall_with_config(self._aget, input, config) From 1baedc4e1802fc13de49116ba6becd20c0860b71 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:28:39 +0100 Subject: [PATCH 15/23] Move patch_config --- libs/langchain/langchain/schema/runnable/base.py | 14 +------------- libs/langchain/langchain/schema/runnable/config.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 1ca853174a2..5fec1c86ca7 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -34,7 +34,6 @@ if TYPE_CHECKING: ) -from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field @@ -43,6 +42,7 @@ from langchain.schema.runnable.config import ( ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, + patch_config, ) from langchain.schema.runnable.utils import ( accepts_run_manager, @@ -1472,18 +1472,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): yield item -def patch_config( - config: RunnableConfig, - callback_manager: BaseCallbackManager, - _locals: Optional[Dict[str, Any]] = None, -) -> RunnableConfig: - config = config.copy() - config["callbacks"] = callback_manager - if _locals is not None: - config["_locals"] = _locals - return config - - def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 716fc361161..00408b7ee6c 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, TypedDict -from langchain.callbacks.base import Callbacks +from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager @@ -40,6 +40,15 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: return empty +def patch_config( + config: RunnableConfig, + callbacks: BaseCallbackManager, +) -> RunnableConfig: + config = config.copy() + config["callbacks"] = callbacks + return config + + def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: return CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), From ddcb4ff5fb3f0ccf3871c2c86744fd8daa436435 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:30:42 +0100 Subject: [PATCH 16/23] Li t --- libs/langchain/langchain/smith/evaluation/runner_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index cc3c46dff14..64139f95e30 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -654,9 +654,7 @@ async def _arun_chain( else: output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) else: - runnable_config = RunnableConfig( - tags=tags or [], callbacks=callbacks, _locals={} - ) + runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) output = await chain.ainvoke(inputs_, config=runnable_config) return output @@ -977,9 +975,7 @@ def _run_chain( else: output = chain(inputs_, callbacks=callbacks, tags=tags) else: - runnable_config = RunnableConfig( - tags=tags or [], callbacks=callbacks, _locals={} - ) + runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) output = chain.invoke(inputs_, config=runnable_config) return output From 6ae58da668f375d4bd5ae162fe21bf1f140ffc36 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:53:10 +0100 Subject: [PATCH 17/23] Assign defaults in batch calls --- libs/langchain/langchain/schema/runnable/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 5fec1c86ca7..fcba9c4c1aa 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -235,7 +235,7 @@ class Runnable(Generic[Input, Output], ABC): ) return ( - config + list(map(ensure_config, config)) if isinstance(config, list) else [deepcopy(ensure_config(config)) for _ in range(length)] ) From d3f10d2f4f49c88747836f281a1651e696f11e20 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 11:36:16 +0100 Subject: [PATCH 18/23] Update test --- .../tests/unit_tests/schema/runnable/test_runnable.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 4d02a07df4a..5d140d2aded 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -127,8 +127,13 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] ) == [5, 7] assert spy.call_args_list == [ - mocker.call("hello", dict(tags=["a-tag"])), - mocker.call("wooorld", dict(metadata={"key": "value"})), + mocker.call( + "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), + mocker.call( + "wooorld", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), ] spy.reset_mock() From 354c42afd20e9cf93ff1a6cd263b4372c5136b22 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 15:30:30 +0100 Subject: [PATCH 19/23] Lint --- libs/langchain/langchain/schema/runnable/locals.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index 6d668059edf..755a709fc95 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -63,7 +63,10 @@ class PutLocalVar(RunnablePassthrough): return super().invoke(input, config=config) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Input: self._concat_put(input, config=config, replace=True) return await super().ainvoke(input, config=config) @@ -139,7 +142,10 @@ class GetLocalVar( return self._call_with_config(self._get, input, config) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], ) -> Union[Output, Dict[str, Union[Input, Output]]]: if config is None: raise ValueError( From 182b059bf4d6bfbbd3204a83985dbe90e9613285 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 21 Aug 2023 17:31:38 -0700 Subject: [PATCH 20/23] param --- .../unit_tests/schema/runnable/test_locals.py | 84 +++++++++++-------- 1 file changed, 49 insertions(+), 35 deletions(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index 8f8755a9644..0430c03c824 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -1,3 +1,5 @@ +from typing import Any, Callable, Type + import pytest from langchain import PromptTemplate @@ -10,30 +12,42 @@ from langchain.schema.runnable import ( ) -@pytest.mark.asyncio -async def test_put_get() -> None: +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.invoke(x), "foo", "foo"), + (lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]), + (lambda r, x: list(r.stream(x))[0], "foo", "foo"), + ], +) +def test_put_get(method: Callable, input: Any, output: Any) -> None: runnable = PutLocalVar("input") | GetLocalVar("input") - assert runnable.invoke("foo") == "foo" - assert runnable.batch(["foo", "bar"]) == ["foo", "bar"] - assert list(runnable.stream("foo"))[0] == "foo" - - assert await runnable.ainvoke("foo") == "foo" - assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"] - async for x in runnable.astream("foo"): - assert x == "foo" + assert method(runnable, input) == output -def test_missing_config() -> None: - with pytest.raises(ValueError): - PutLocalVar("input").invoke("foo") - - with pytest.raises(ValueError): - GetLocalVar[str, str]("input").invoke("foo") +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.ainvoke(x), "foo", "foo"), + (lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]), + ], +) +async def test_put_get_async(method: Callable, input: Any, output: Any) -> None: + runnable = PutLocalVar("input") | GetLocalVar("input") + assert await method(runnable, input) == output -def test_get_missing_var_invoke() -> None: - runnable = PutLocalVar("input") | GetLocalVar("missing") - with pytest.raises(KeyError): +@pytest.mark.parametrize( + ("runnable", "error"), + [ + (PutLocalVar("input"), ValueError), + (GetLocalVar("input"), ValueError), + (PutLocalVar("input") | GetLocalVar("missing"), KeyError), + ], +) +def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None: + with pytest.raises(error): runnable.invoke("foo") @@ -42,24 +56,24 @@ def test_get_in_map() -> None: assert runnable.invoke("foo") == {"bar": "foo"} -def test_cant_put_in_map() -> None: +def test_put_in_map() -> None: runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") with pytest.raises(KeyError): runnable.invoke("foo") -def test_get_passthrough_key() -> None: - runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output") - assert runnable.invoke("foo") == {"input": "foo", "output": "foo"} - - -def test_multi_step_sequence() -> None: - prompt = PromptTemplate.from_template("say {foo}") - runnable = ( - PutLocalVar("foo") - | {"foo": RunnablePassthrough()} - | prompt - | FakeListLLM(responses=["bar"]) - | GetLocalVar("foo", passthrough_key="output") - ) - assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"} +@pytest.mark.parametrize( + "runnable", + [ + PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"), + ( + PutLocalVar("input") + | {"input": RunnablePassthrough()} + | PromptTemplate.from_template("say {input}") + | FakeListLLM(responses=["hello"]) + | GetLocalVar("input", passthrough_key="output") + ), + ], +) +def test_put_get_sequence(runnable: RunnableSequence) -> None: + assert runnable.invoke("hello") == {"input": "hello", "output": "hello"} From a9bf409a0900730e88d2f1ffd087c818137fe8df Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 21 Aug 2023 17:37:07 -0700 Subject: [PATCH 21/23] param --- .../unit_tests/schema/runnable/test_locals.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index 0430c03c824..ee07c0cfc6e 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -75,5 +75,19 @@ def test_put_in_map() -> None: ), ], ) -def test_put_get_sequence(runnable: RunnableSequence) -> None: - assert runnable.invoke("hello") == {"input": "hello", "output": "hello"} +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}), + (lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]), + ( + lambda r, x: list(r.stream(x))[0], + "hello", + {"input": "hello", "output": "hello"}, + ), + ], +) +def test_put_get_sequence( + runnable: RunnableSequence, method: Callable, input: Any, output: Any +) -> None: + assert method(runnable, input) == output From 4e7e6bfe0a7bd15c4ccd72ed33fe1b35b47be3ef Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 21 Aug 2023 18:01:49 -0700 Subject: [PATCH 22/23] revert --- libs/langchain/langchain/chains/base.py | 23 +++++++++++------ libs/langchain/langchain/chat_models/base.py | 20 +++++++++------ libs/langchain/langchain/llms/base.py | 27 ++++++++++++-------- libs/langchain/langchain/schema/retriever.py | 20 +++++++++------ 4 files changed, 56 insertions(+), 34 deletions(-) diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 8a49784f7dc..5a21dc6a661 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -63,10 +63,13 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): **kwargs: Any, ) -> Dict[str, Any]: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return self(input, **config_kwargs, **kwargs) + return self( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, + ) async def ainvoke( self, @@ -79,11 +82,15 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): return await asyncio.get_running_loop().run_in_executor( None, partial(self.invoke, input, config, **kwargs) ) + config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return await self.acall(input, **config_kwargs, **kwargs) + return await self.acall( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, + ) memory: Optional[BaseMemory] = None """Optional memory object. Defaults to None. diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index d4c582c19e0..09199e30dc9 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -105,15 +105,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): **kwargs: Any, ) -> BaseMessageChunk: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } return cast( BaseMessageChunk, cast( ChatGeneration, self.generate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, ).generations[0][0], ).message, ) @@ -133,11 +135,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): ) config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, ) return cast( BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 401fe61d067..a833487ffb1 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -220,13 +220,18 @@ class BaseLLM(BaseLanguageModel[str], ABC): **kwargs: Any, ) -> str: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - result = self.generate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + return ( + self.generate_prompt( + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, + ) + .generations[0][0] + .text ) - return result.generations[0][0].text async def ainvoke( self, @@ -243,11 +248,13 @@ class BaseLLM(BaseLanguageModel[str], ABC): ) config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } llm_result = await self.agenerate_prompt( - [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs + [self._convert_input(input)], + stop=stop, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + **kwargs, ) return llm_result.generations[0][0].text diff --git a/libs/langchain/langchain/schema/retriever.py b/libs/langchain/langchain/schema/retriever.py index 55a1acb086e..5da50e1497e 100644 --- a/libs/langchain/langchain/schema/retriever.py +++ b/libs/langchain/langchain/schema/retriever.py @@ -108,10 +108,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): self, input: str, config: Optional[RunnableConfig] = None ) -> List[Document]: config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return self.get_relevant_documents(input, **config_kwargs) + return self.get_relevant_documents( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + ) async def ainvoke( self, @@ -124,10 +126,12 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC): return await super().ainvoke(input, config) config = config or {} - config_kwargs: Dict = { - k: config.get(k) for k in ("callbacks", "tags", "metadata") - } - return await self.aget_relevant_documents(input, **config_kwargs) + return await self.aget_relevant_documents( + input, + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + ) @abstractmethod def _get_relevant_documents( From ef2500584cbb50527b738d3af4b10b8ead56f0f9 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 23 Aug 2023 10:15:45 -0700 Subject: [PATCH 23/23] fmt --- libs/langchain/langchain/schema/runnable/__init__.py | 2 +- .../langchain/schema/runnable/{locals.py => _locals.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename libs/langchain/langchain/schema/runnable/{locals.py => _locals.py} (100%) diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index bae6aebb024..16f99324b0f 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -1,3 +1,4 @@ +from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, @@ -7,7 +8,6 @@ from langchain.schema.runnable.base import ( RunnableWithFallbacks, ) from langchain.schema.runnable.config import RunnableConfig -from langchain.schema.runnable.locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/_locals.py similarity index 100% rename from libs/langchain/langchain/schema/runnable/locals.py rename to libs/langchain/langchain/schema/runnable/_locals.py