diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 2669409a3a2..0facc1ea6d8 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy from typing import ( Any, AsyncIterator, @@ -189,6 +190,15 @@ class Runnable(Generic[Input, Output], ABC): ) return output + def _invoke_with_locals( + self, + input: Input, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, + ) -> Output: + return self.invoke(input, config=config) + class RunnableSequence(Serializable, Runnable[Input, Output]): first: Runnable[Input, Any] @@ -249,6 +259,16 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + locals: Dict[str, Any] = {} + return self._invoke_with_locals(input, locals, config=config) + + def _invoke_with_locals( + self, + input: Input, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, + ) -> Output: from langchain.callbacks.manager import CallbackManager # setup callbacks @@ -270,10 +290,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke all steps in sequence try: for step in self.steps: - input = step.invoke( + input = step._invoke_with_locals( input, + _locals, # mark each step as a child run - _patch_config(config, run_manager.get_child()), + config=_patch_config(config, run_manager.get_child()), ) # finish the root run except (KeyboardInterrupt, Exception) as e: @@ -595,6 +616,15 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + return self._invoke_with_locals(input, {}, config=config) + + def _invoke_with_locals( + self, + input: Input, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, ) -> Dict[str, Any]: from langchain.callbacks.manager import CallbackManager @@ -619,10 +649,12 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): with ThreadPoolExecutor() as executor: futures = [ executor.submit( - step.invoke, + step._invoke_with_locals, input, + # locals are read-only in a map step + deepcopy(_locals), # mark each step as a child run - _patch_config(config, run_manager.get_child()), + config=_patch_config(config, run_manager.get_child()), ) for step in steps.values() ] @@ -709,6 +741,48 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): return self._call_with_config(lambda x: x, input, config) +class PutLocalVar(Serializable, Runnable[Input, Input]): + key: str + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + raise NotImplementedError + + def _invoke_with_locals( + self, + input: Input, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, + ) -> Input: + _locals[self.key] = input + return self._call_with_config(lambda x: x, input, config) + + +class GetLocalVar(Serializable, Runnable[str, Any]): + key: str + return_input_key: Optional[str] = None + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: + raise NotImplementedError + + def _invoke_with_locals( + self, + input: str, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, + ) -> Any: + if self.return_input_key is not None: + return {self.key: _locals[self.key], self.return_input_key: input} + return _locals[self.key] + + class RunnableBinding(Serializable, Runnable[Input, Output]): bound: Runnable[Input, Output] @@ -727,6 +801,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: return self.bound.invoke(input, config, **self.kwargs) + def _invoke_with_locals( + self, + input: Input, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, + ) -> Output: + return self.bound._invoke_with_locals( + input, _locals, config=config, **self.kwargs + ) + async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: @@ -810,6 +895,15 @@ class RouterRunnable( def invoke( self, input: RouterInput, config: Optional[RunnableConfig] = None + ) -> Output: + return self._invoke_with_locals(input, {}, config=config) + + def _invoke_with_locals( + self, + input: RouterInput, + _locals: Dict[str, Any], + *, + config: Optional[RunnableConfig] = None, ) -> Output: key = input["key"] actual_input = input["input"] @@ -817,7 +911,7 @@ class RouterRunnable( raise ValueError(f"No runnable associated with key '{key}'") runnable = self.runnables[key] - return runnable.invoke(actual_input, config) + return runnable._invoke_with_locals(actual_input, _locals, config=config) async def ainvoke( self, input: RouterInput, config: Optional[RunnableConfig] = None