This commit is contained in:
Bagatur 2023-08-17 16:22:12 -07:00
parent c447e9a854
commit 6b0a849f59
3 changed files with 30 additions and 12 deletions

View File

@ -7,10 +7,13 @@ from langchain.schema.runnable.base import (
RunnableWithFallbacks, RunnableWithFallbacks,
) )
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [ __all__ = [
"GetLocalVar",
"PutLocalVar",
"RouterInput", "RouterInput",
"RouterRunnable", "RouterRunnable",
"Runnable", "Runnable",

View File

@ -238,7 +238,10 @@ class Runnable(Generic[Input, Output], ABC):
return ( return (
config config
if isinstance(config, list) if isinstance(config, list)
else [deepcopy(config) if config is not None else {} for _ in range(length)] else [
deepcopy(config) if config is not None else _empty_config()
for _ in range(length)
]
) )
def _call_with_config( def _call_with_config(
@ -750,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -896,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]: ) -> Iterator[Output]:
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -963,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(

View File

@ -3,8 +3,9 @@ from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough from langchain.schema.runnable.base import Input, Output, Runnable
from langchain.schema.runnable.base import Input, Output from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
class PutLocalVar(RunnablePassthrough): class PutLocalVar(RunnablePassthrough):
@ -27,7 +28,12 @@ class PutLocalVar(RunnablePassthrough):
) )
if isinstance(self.key, str): if isinstance(self.key, str):
config["_locals"][self.key] = input config["_locals"][self.key] = input
elif isinstance(input, Mapping): elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items(): for input_key, put_key in self.key.items():
config["_locals"][put_key] = input[input_key] config["_locals"][put_key] = input[input_key]
else: else:
@ -44,17 +50,23 @@ class PutLocalVar(RunnablePassthrough):
"PutLocalVar should only be used in a RunnableSequence, and should " "PutLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config." "therefore always receive a non-null config."
) )
print(config)
if isinstance(self.key, str): if isinstance(self.key, str):
if self.key not in config["_locals"]: if self.key not in config["_locals"]:
config["_locals"][self.key] = input config["_locals"][self.key] = input
else: else:
config["_locals"][self.key] += input config["_locals"][self.key] += input
elif isinstance(input, Mapping): elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
f"Received key of type Mapping but input of type {type(input)}. "
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items(): for input_key, put_key in self.key.items():
if put_key not in config["_locals"]: if put_key not in config["_locals"]:
config["_locals"][put_key] = input config["_locals"][put_key] = input[input_key]
else: else:
config["_locals"][put_key] += input config["_locals"][put_key] += input[input_key]
else: else:
raise TypeError( raise TypeError(
f"`key` should be a string or Mapping[str, str], received type " f"`key` should be a string or Mapping[str, str], received type "
@ -75,14 +87,14 @@ class PutLocalVar(RunnablePassthrough):
self, input: Iterator[Input], config: RunnableConfig | None = None self, input: Iterator[Input], config: RunnableConfig | None = None
) -> Iterator[Input]: ) -> Iterator[Input]:
for chunk in super().transform(input, config=config): for chunk in super().transform(input, config=config):
self._concat_put(input, config=config) self._concat_put(chunk, config=config)
yield chunk yield chunk
async def atransform( async def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None self, input: AsyncIterator[Input], config: RunnableConfig | None = None
) -> AsyncIterator[Input]: ) -> AsyncIterator[Input]:
async for chunk in super().atransform(input, config=config): async for chunk in super().atransform(input, config=config):
self._concat_put(input, config=config) self._concat_put(chunk, config=config)
yield chunk yield chunk