From c447e9a854deef90de62bad39991f1ea55a8f29b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 15:29:00 -0700 Subject: [PATCH] cr --- .../langchain/schema/runnable/base.py | 8 ++-- .../langchain/schema/runnable/locals.py | 40 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index ee3f7c11427..3f3e90ba27f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -243,8 +243,8 @@ class Runnable(Generic[Input, Output], ABC): def _call_with_config( self, - func: Callable[[Input], Output], - input: Input, + func: Callable[[Any], Output], + input: Any, config: Optional[RunnableConfig], run_type: Optional[str] = None, ) -> Output: @@ -273,8 +273,8 @@ class Runnable(Generic[Input, Output], ABC): async def _acall_with_config( self, - func: Callable[[Input], Awaitable[Output]], - input: Input, + func: Callable[[Any], Awaitable[Output]], + input: Any, config: Optional[RunnableConfig], run_type: Optional[str] = None, ) -> Output: diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index 53d8f5a2ce6..cf51336dc9c 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -16,7 +16,7 @@ class PutLocalVar(RunnablePassthrough): stored in local state under the map keys. """ - def __init__(self, key: str, **kwargs: Any) -> None: + def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: super().__init__(key=key, **kwargs) def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None: @@ -63,13 +63,13 @@ class PutLocalVar(RunnablePassthrough): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: self._put(input, config=config) - return super().invoke(input, config) + return super().invoke(input, config=config) async def ainvoke( self, input: Input, config: RunnableConfig | None = None ) -> Input: self._put(input, config=config) - return await super().ainvoke(input, config) + return await super().ainvoke(input, config=config) def transform( self, input: Iterator[Input], config: RunnableConfig | None = None @@ -102,14 +102,40 @@ class GetLocalVar( def __init__(self, key: str, **kwargs: Any) -> None: super().__init__(key=key, **kwargs) + def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]: + if self.passthrough_key: + return { + self.key: full_input["locals"][self.key], + self.passthrough_key: full_input["input"], + } + else: + return full_input["locals"][self.key] + + async def _aget( + self, full_input: Dict + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + return self._get(full_input) + def invoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Union[Output, Dict[str, Union[Input, Output]]]: if config is None: raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " + "GetLocalVar should only be used in a RunnableSequence, and should " "therefore always receive a non-null config." ) - if self.passthrough_key is not None: - return {self.key: config["_locals"][self.key], self.passthrough_key: input} - return config["_locals"][self.key] + + log_input = {"input": input, "locals": config["_locals"]} + return self._call_with_config(self._get, log_input, config) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if config is None: + raise ValueError( + "GetLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + + log_input = {"input": input, "locals": config["_locals"]} + return await self._acall_with_config(self._aget, log_input, config)