mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
rfc
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
@@ -189,6 +190,15 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
return output
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: Input,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Output:
|
||||
return self.invoke(input, config=config)
|
||||
|
||||
|
||||
class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
first: Runnable[Input, Any]
|
||||
@@ -249,6 +259,16 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
locals: Dict[str, Any] = {}
|
||||
return self._invoke_with_locals(input, locals, config=config)
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: Input,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
@@ -270,10 +290,11 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
# invoke all steps in sequence
|
||||
try:
|
||||
for step in self.steps:
|
||||
input = step.invoke(
|
||||
input = step._invoke_with_locals(
|
||||
input,
|
||||
_locals,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
config=_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
@@ -595,6 +616,15 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
return self._invoke_with_locals(input, {}, config=config)
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: Input,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
@@ -619,10 +649,12 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
step.invoke,
|
||||
step._invoke_with_locals,
|
||||
input,
|
||||
# locals are read-only in a map step
|
||||
deepcopy(_locals),
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
config=_patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
for step in steps.values()
|
||||
]
|
||||
@@ -709,6 +741,48 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
return self._call_with_config(lambda x: x, input, config)
|
||||
|
||||
|
||||
class PutLocalVar(Serializable, Runnable[Input, Input]):
|
||||
key: str
|
||||
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
raise NotImplementedError
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: Input,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Input:
|
||||
_locals[self.key] = input
|
||||
return self._call_with_config(lambda x: x, input, config)
|
||||
|
||||
|
||||
class GetLocalVar(Serializable, Runnable[str, Any]):
|
||||
key: str
|
||||
return_input_key: Optional[str] = None
|
||||
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: str,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Any:
|
||||
if self.return_input_key is not None:
|
||||
return {self.key: _locals[self.key], self.return_input_key: input}
|
||||
return _locals[self.key]
|
||||
|
||||
|
||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
@@ -727,6 +801,17 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
return self.bound.invoke(input, config, **self.kwargs)
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: Input,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Output:
|
||||
return self.bound._invoke_with_locals(
|
||||
input, _locals, config=config, **self.kwargs
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
@@ -810,6 +895,15 @@ class RouterRunnable(
|
||||
|
||||
def invoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
return self._invoke_with_locals(input, {}, config=config)
|
||||
|
||||
def _invoke_with_locals(
|
||||
self,
|
||||
input: RouterInput,
|
||||
_locals: Dict[str, Any],
|
||||
*,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
@@ -817,7 +911,7 @@ class RouterRunnable(
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return runnable.invoke(actual_input, config)
|
||||
return runnable._invoke_with_locals(actual_input, _locals, config=config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
|
||||
Reference in New Issue
Block a user