This commit is contained in:
Bagatur
2023-08-04 13:18:07 -07:00
parent bd61757423
commit 7daf31bf8e

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import (
Any,
AsyncIterator,
@@ -189,6 +190,15 @@ class Runnable(Generic[Input, Output], ABC):
)
return output
def _invoke_with_locals(
self,
input: Input,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Output:
return self.invoke(input, config=config)
class RunnableSequence(Serializable, Runnable[Input, Output]):
first: Runnable[Input, Any]
@@ -249,6 +259,16 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
locals: Dict[str, Any] = {}
return self._invoke_with_locals(input, locals, config=config)
def _invoke_with_locals(
self,
input: Input,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
@@ -270,10 +290,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke all steps in sequence
try:
for step in self.steps:
input = step.invoke(
input = step._invoke_with_locals(
input,
_locals,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
config=_patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@@ -595,6 +616,15 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
return self._invoke_with_locals(input, {}, config=config)
def _invoke_with_locals(
self,
input: Input,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Dict[str, Any]:
from langchain.callbacks.manager import CallbackManager
@@ -619,10 +649,12 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
step.invoke,
step._invoke_with_locals,
input,
# locals are read-only in a map step
deepcopy(_locals),
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
config=_patch_config(config, run_manager.get_child()),
)
for step in steps.values()
]
@@ -709,6 +741,48 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
return self._call_with_config(lambda x: x, input, config)
class PutLocalVar(Serializable, Runnable[Input, Input]):
key: str
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
raise NotImplementedError
def _invoke_with_locals(
self,
input: Input,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Input:
_locals[self.key] = input
return self._call_with_config(lambda x: x, input, config)
class GetLocalVar(Serializable, Runnable[str, Any]):
key: str
return_input_key: Optional[str] = None
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
raise NotImplementedError
def _invoke_with_locals(
self,
input: str,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Any:
if self.return_input_key is not None:
return {self.key: _locals[self.key], self.return_input_key: input}
return _locals[self.key]
class RunnableBinding(Serializable, Runnable[Input, Output]):
bound: Runnable[Input, Output]
@@ -727,6 +801,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
return self.bound.invoke(input, config, **self.kwargs)
def _invoke_with_locals(
self,
input: Input,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Output:
return self.bound._invoke_with_locals(
input, _locals, config=config, **self.kwargs
)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
@@ -810,6 +895,15 @@ class RouterRunnable(
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
return self._invoke_with_locals(input, {}, config=config)
def _invoke_with_locals(
self,
input: RouterInput,
_locals: Dict[str, Any],
*,
config: Optional[RunnableConfig] = None,
) -> Output:
key = input["key"]
actual_input = input["input"]
@@ -817,7 +911,7 @@ class RouterRunnable(
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return runnable.invoke(actual_input, config)
return runnable._invoke_with_locals(actual_input, _locals, config=config)
async def ainvoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None