This commit is contained in:
Bagatur 2023-08-17 16:22:12 -07:00
parent c447e9a854
commit 6b0a849f59
3 changed files with 30 additions and 12 deletions

View File

@ -7,10 +7,13 @@ from langchain.schema.runnable.base import (
RunnableWithFallbacks,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [
"GetLocalVar",
"PutLocalVar",
"RouterInput",
"RouterRunnable",
"Runnable",

View File

@ -238,7 +238,10 @@ class Runnable(Generic[Input, Output], ABC):
return (
config
if isinstance(config, list)
else [deepcopy(config) if config is not None else {} for _ in range(length)]
else [
deepcopy(config) if config is not None else _empty_config()
for _ in range(length)
]
)
def _call_with_config(
@ -750,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@ -896,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@ -963,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
# setup callbacks
config = config or {}
config = config or _empty_config()
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(

View File

@ -3,8 +3,9 @@ from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.schema.runnable.base import Input, Output
from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
class PutLocalVar(RunnablePassthrough):
@ -27,7 +28,12 @@ class PutLocalVar(RunnablePassthrough):
)
if isinstance(self.key, str):
config["_locals"][self.key] = input
elif isinstance(input, Mapping):
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
config["_locals"][put_key] = input[input_key]
else:
@ -44,17 +50,23 @@ class PutLocalVar(RunnablePassthrough):
"PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
print(config)
if isinstance(self.key, str):
if self.key not in config["_locals"]:
config["_locals"][self.key] = input
else:
config["_locals"][self.key] += input
elif isinstance(input, Mapping):
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
if put_key not in config["_locals"]:
config["_locals"][put_key] = input
config["_locals"][put_key] = input[input_key]
else:
config["_locals"][put_key] += input
config["_locals"][put_key] += input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
@ -75,14 +87,14 @@ class PutLocalVar(RunnablePassthrough):
self, input: Iterator[Input], config: RunnableConfig | None = None
) -> Iterator[Input]:
for chunk in super().transform(input, config=config):
self._concat_put(input, config=config)
self._concat_put(chunk, config=config)
yield chunk
async def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
) -> AsyncIterator[Input]:
async for chunk in super().atransform(input, config=config):
self._concat_put(input, config=config)
self._concat_put(chunk, config=config)
yield chunk