From bd80cad6dbd045e36afe4be4071d1ef612ff9ea9 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Thu, 17 Aug 2023 13:52:19 -0700 Subject: [PATCH] add --- .../langchain/schema/runnable/locals.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 libs/langchain/langchain/schema/runnable/locals.py diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py new file mode 100644 index 00000000000..53d8f5a2ce6 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union + +from langchain.load.serializable import Serializable +from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough +from langchain.schema.runnable.base import Input, Output + + +class PutLocalVar(RunnablePassthrough): + key: Union[str, Mapping[str, str]] + """The key(s) to use for storing the input variable(s) in local state. + + If a string is provided then the entire input is stored under that key. If a + Mapping is provided, then the map values are gotten from the input and + stored in local state under the map keys. + """ + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if isinstance(self.key, str): + config["_locals"][self.key] = input + elif isinstance(input, Mapping): + for input_key, put_key in self.key.items(): + config["_locals"][put_key] = input[input_key] + else: + raise TypeError( + f"`key` should be a string or Mapping[str, str], received type " + f"{(type(self.key))}." + ) + + def _concat_put( + self, input: Input, *, config: Optional[RunnableConfig] = None + ) -> None: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if isinstance(self.key, str): + if self.key not in config["_locals"]: + config["_locals"][self.key] = input + else: + config["_locals"][self.key] += input + elif isinstance(input, Mapping): + for input_key, put_key in self.key.items(): + if put_key not in config["_locals"]: + config["_locals"][put_key] = input + else: + config["_locals"][put_key] += input + else: + raise TypeError( + f"`key` should be a string or Mapping[str, str], received type " + f"{(type(self.key))}." + ) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + self._put(input, config=config) + return super().invoke(input, config) + + async def ainvoke( + self, input: Input, config: RunnableConfig | None = None + ) -> Input: + self._put(input, config=config) + return await super().ainvoke(input, config) + + def transform( + self, input: Iterator[Input], config: RunnableConfig | None = None + ) -> Iterator[Input]: + for chunk in super().transform(input, config=config): + self._concat_put(input, config=config) + yield chunk + + async def atransform( + self, input: AsyncIterator[Input], config: RunnableConfig | None = None + ) -> AsyncIterator[Input]: + async for chunk in super().atransform(input, config=config): + self._concat_put(input, config=config) + yield chunk + + +class GetLocalVar( + Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] +): + key: str + """The key to extract from the local state.""" + passthrough_key: Optional[str] = None + """The key to use for passing through the invocation input. + + If None, then only the value retrieved from local state is returned. Otherwise a + dictionary ``{self.key: <>, self.passthrough_key: <>}`` + is returned. + """ + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if self.passthrough_key is not None: + return {self.key: config["_locals"][self.key], self.passthrough_key: input} + return config["_locals"][self.key]