This commit is contained in:
Bagatur 2023-08-17 13:52:19 -07:00
parent 8c1a528c71
commit bd80cad6db

View File

@ -0,0 +1,115 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.schema.runnable.base import Input, Output
class PutLocalVar(RunnablePassthrough):
key: Union[str, Mapping[str, str]]
"""The key(s) to use for storing the input variable(s) in local state.
If a string is provided then the entire input is stored under that key. If a
Mapping is provided, then the map values are gotten from the input and
stored in local state under the map keys.
"""
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
config["_locals"][self.key] = input
elif isinstance(input, Mapping):
for input_key, put_key in self.key.items():
config["_locals"][put_key] = input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
f"{(type(self.key))}."
)
def _concat_put(
self, input: Input, *, config: Optional[RunnableConfig] = None
) -> None:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
if self.key not in config["_locals"]:
config["_locals"][self.key] = input
else:
config["_locals"][self.key] += input
elif isinstance(input, Mapping):
for input_key, put_key in self.key.items():
if put_key not in config["_locals"]:
config["_locals"][put_key] = input
else:
config["_locals"][put_key] += input
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
f"{(type(self.key))}."
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
self._put(input, config=config)
return super().invoke(input, config)
async def ainvoke(
self, input: Input, config: RunnableConfig | None = None
) -> Input:
self._put(input, config=config)
return await super().ainvoke(input, config)
def transform(
self, input: Iterator[Input], config: RunnableConfig | None = None
) -> Iterator[Input]:
for chunk in super().transform(input, config=config):
self._concat_put(input, config=config)
yield chunk
async def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
) -> AsyncIterator[Input]:
async for chunk in super().atransform(input, config=config):
self._concat_put(input, config=config)
yield chunk
class GetLocalVar(
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
):
key: str
"""The key to extract from the local state."""
passthrough_key: Optional[str] = None
"""The key to use for passing through the invocation input.
If None, then only the value retrieved from local state is returned. Otherwise a
dictionary ``{self.key: <<retrieved_value>>, self.passthrough_key: <<input>>}``
is returned.
"""
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if self.passthrough_key is not None:
return {self.key: config["_locals"][self.key], self.passthrough_key: input}
return config["_locals"][self.key]