diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 0a36467f3f8..1b3061a259a 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -54,6 +54,7 @@ from langchain.schema.runnable.config import ( get_callback_manager_for_config, get_config_list, get_executor_for_config, + merge_configs, patch_config, ) from langchain.schema.runnable.utils import ( @@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC): Bind config to a Runnable, returning a new Runnable. """ return RunnableBinding( - bound=self, config={**(config or {}), **kwargs}, kwargs={} + bound=self, + config=cast( + RunnableConfig, + {**(config or {}), **kwargs}, + ), # type: ignore[misc] + kwargs={}, ) def with_retry( @@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): kwargs: Mapping[str, Any] - config: Mapping[str, Any] = Field(default_factory=dict) + config: RunnableConfig = Field(default_factory=dict) class Config: arbitrary_types_allowed = True @@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): *, bound: Runnable[Input, Output], kwargs: Mapping[str, Any], - config: Optional[Mapping[str, Any]] = None, + config: Optional[RunnableConfig] = None, **other_kwargs: Any, ) -> None: config = config or {} @@ -2347,22 +2353,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig: - copy = cast(RunnableConfig, dict(self.config)) - if config: - for key in config: - if key == "metadata": - copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore - elif key == "tags": - copy[key] = (copy.get(key) or []) + config[key] # type: ignore - elif key == "configurable": - copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore - else: - # Even though the keys aren't literals this is correct - # because both dicts are same type - copy[key] = config[key] or copy.get(key) # type: ignore - return copy - def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} @@ -2377,7 +2367,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): return self.__class__( bound=self.bound, kwargs=self.kwargs, - config={**self.config, **(config or {}), **kwargs}, + config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), ) def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: @@ -2395,7 +2385,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Output: return self.bound.invoke( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2407,7 +2397,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Output: return await self.bound.ainvoke( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2421,11 +2411,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> List[Output]: if isinstance(config, list): configs = cast( - List[RunnableConfig], [self._merge_config(conf) for conf in config] + List[RunnableConfig], + [merge_configs(self.config, conf) for conf in config], ) else: configs = [ - patch_config(self._merge_config(config), copy_locals=True) + patch_config(merge_configs(self.config, config), copy_locals=True) for _ in range(len(inputs)) ] return self.bound.batch( @@ -2445,11 +2436,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> List[Output]: if isinstance(config, list): configs = cast( - List[RunnableConfig], [self._merge_config(conf) for conf in config] + List[RunnableConfig], + [merge_configs(self.config, conf) for conf in config], ) else: configs = [ - patch_config(self._merge_config(config), copy_locals=True) + patch_config(merge_configs(self.config, config), copy_locals=True) for _ in range(len(inputs)) ] return await self.bound.abatch( @@ -2467,7 +2459,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.stream( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2479,7 +2471,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.astream( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ): yield item @@ -2492,7 +2484,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.transform( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2504,7 +2496,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.atransform( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ): yield item diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 71eb7428e06..1b720fb5b6e 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -157,6 +157,31 @@ def patch_config( return config +def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: + base: RunnableConfig = {} + # Even though the keys aren't literals this is correct + # because both dicts are same type + for config in (c for c in configs if c is not None): + for key in config: + if key == "metadata": + base[key] = { # type: ignore + **base.get(key, {}), # type: ignore + **(config.get(key) or {}), # type: ignore + } + elif key == "tags": + base[key] = list( # type: ignore + set(base.get(key, []) + (config.get(key) or [])), # type: ignore + ) + elif key == "configurable": + base[key] = { # type: ignore + **base.get(key, {}), # type: ignore + **(config.get(key) or {}), # type: ignore + } + else: + base[key] = config[key] or base.get(key) # type: ignore + return base + + def call_func_with_variable_args( func: Union[ Callable[[Input], Output],