diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 0dbabd1579d..bae6aebb024 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -7,10 +7,13 @@ from langchain.schema.runnable.base import ( RunnableWithFallbacks, ) from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable __all__ = [ + "GetLocalVar", + "PutLocalVar", "RouterInput", "RouterRunnable", "Runnable", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 3f3e90ba27f..704a518cde8 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -238,7 +238,10 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [deepcopy(config) if config is not None else {} for _ in range(length)] + else [ + deepcopy(config) if config is not None else _empty_config() + for _ in range(length) + ] ) def _call_with_config( @@ -750,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -896,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -963,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: # setup callbacks - config = config or {} + config = config or _empty_config() callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( diff --git a/libs/langchain/langchain/schema/runnable/locals.py b/libs/langchain/langchain/schema/runnable/locals.py index cf51336dc9c..65e63507bc4 100644 --- a/libs/langchain/langchain/schema/runnable/locals.py +++ b/libs/langchain/langchain/schema/runnable/locals.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union from langchain.load.serializable import Serializable -from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough -from langchain.schema.runnable.base import Input, Output +from langchain.schema.runnable.base import Input, Output, Runnable +from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.passthrough import RunnablePassthrough class PutLocalVar(RunnablePassthrough): @@ -27,7 +28,12 @@ class PutLocalVar(RunnablePassthrough): ) if isinstance(self.key, str): config["_locals"][self.key] = input - elif isinstance(input, Mapping): + elif isinstance(self.key, Mapping): + if not isinstance(input, Mapping): + raise TypeError( + f"Received key of type Mapping but input of type {type(input)}. " + f"input is expected to be of type Mapping when key is Mapping." + ) for input_key, put_key in self.key.items(): config["_locals"][put_key] = input[input_key] else: @@ -44,17 +50,23 @@ class PutLocalVar(RunnablePassthrough): "PutLocalVar should only be used in a RunnableSequence, and should " "therefore always receive a non-null config." ) + print(config) if isinstance(self.key, str): if self.key not in config["_locals"]: config["_locals"][self.key] = input else: config["_locals"][self.key] += input - elif isinstance(input, Mapping): + elif isinstance(self.key, Mapping): + if not isinstance(input, Mapping): + raise TypeError( + f"Received key of type Mapping but input of type {type(input)}. " + f"input is expected to be of type Mapping when key is Mapping." + ) for input_key, put_key in self.key.items(): if put_key not in config["_locals"]: - config["_locals"][put_key] = input + config["_locals"][put_key] = input[input_key] else: - config["_locals"][put_key] += input + config["_locals"][put_key] += input[input_key] else: raise TypeError( f"`key` should be a string or Mapping[str, str], received type " @@ -75,14 +87,14 @@ class PutLocalVar(RunnablePassthrough): self, input: Iterator[Input], config: RunnableConfig | None = None ) -> Iterator[Input]: for chunk in super().transform(input, config=config): - self._concat_put(input, config=config) + self._concat_put(chunk, config=config) yield chunk async def atransform( self, input: AsyncIterator[Input], config: RunnableConfig | None = None ) -> AsyncIterator[Input]: async for chunk in super().atransform(input, config=config): - self._concat_put(input, config=config) + self._concat_put(chunk, config=config) yield chunk