mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-08 00:28:47 +00:00
add
This commit is contained in:
parent
8c1a528c71
commit
bd80cad6db
115
libs/langchain/langchain/schema/runnable/locals.py
Normal file
115
libs/langchain/langchain/schema/runnable/locals.py
Normal file
@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
|
||||
from langchain.schema.runnable.base import Input, Output
|
||||
|
||||
|
||||
class PutLocalVar(RunnablePassthrough):
|
||||
key: Union[str, Mapping[str, str]]
|
||||
"""The key(s) to use for storing the input variable(s) in local state.
|
||||
|
||||
If a string is provided then the entire input is stored under that key. If a
|
||||
Mapping is provided, then the map values are gotten from the input and
|
||||
stored in local state under the map keys.
|
||||
"""
|
||||
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
if isinstance(self.key, str):
|
||||
config["_locals"][self.key] = input
|
||||
elif isinstance(input, Mapping):
|
||||
for input_key, put_key in self.key.items():
|
||||
config["_locals"][put_key] = input[input_key]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"`key` should be a string or Mapping[str, str], received type "
|
||||
f"{(type(self.key))}."
|
||||
)
|
||||
|
||||
def _concat_put(
|
||||
self, input: Input, *, config: Optional[RunnableConfig] = None
|
||||
) -> None:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
if isinstance(self.key, str):
|
||||
if self.key not in config["_locals"]:
|
||||
config["_locals"][self.key] = input
|
||||
else:
|
||||
config["_locals"][self.key] += input
|
||||
elif isinstance(input, Mapping):
|
||||
for input_key, put_key in self.key.items():
|
||||
if put_key not in config["_locals"]:
|
||||
config["_locals"][put_key] = input
|
||||
else:
|
||||
config["_locals"][put_key] += input
|
||||
else:
|
||||
raise TypeError(
|
||||
f"`key` should be a string or Mapping[str, str], received type "
|
||||
f"{(type(self.key))}."
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
self._put(input, config=config)
|
||||
return super().invoke(input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: RunnableConfig | None = None
|
||||
) -> Input:
|
||||
self._put(input, config=config)
|
||||
return await super().ainvoke(input, config)
|
||||
|
||||
def transform(
|
||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
||||
) -> Iterator[Input]:
|
||||
for chunk in super().transform(input, config=config):
|
||||
self._concat_put(input, config=config)
|
||||
yield chunk
|
||||
|
||||
async def atransform(
|
||||
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
||||
) -> AsyncIterator[Input]:
|
||||
async for chunk in super().atransform(input, config=config):
|
||||
self._concat_put(input, config=config)
|
||||
yield chunk
|
||||
|
||||
|
||||
class GetLocalVar(
|
||||
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
||||
):
|
||||
key: str
|
||||
"""The key to extract from the local state."""
|
||||
passthrough_key: Optional[str] = None
|
||||
"""The key to use for passing through the invocation input.
|
||||
|
||||
If None, then only the value retrieved from local state is returned. Otherwise a
|
||||
dictionary ``{self.key: <<retrieved_value>>, self.passthrough_key: <<input>>}``
|
||||
is returned.
|
||||
"""
|
||||
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
if self.passthrough_key is not None:
|
||||
return {self.key: config["_locals"][self.key], self.passthrough_key: input}
|
||||
return config["_locals"][self.key]
|
Loading…
Reference in New Issue
Block a user