From b0d5882fe19d361b0ed0e01bb1ee5f083562fa4e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Oct 2023 13:22:07 +0100 Subject: [PATCH] Export merge_configs function --- .../langchain/schema/runnable/base.py | 52 ++++++++----------- .../langchain/schema/runnable/config.py | 17 ++++++ 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c053d2f6257..014b2857ff6 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -54,6 +54,7 @@ from langchain.schema.runnable.config import ( get_callback_manager_for_config, get_config_list, get_executor_for_config, + merge_configs, patch_config, ) from langchain.schema.runnable.utils import ( @@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC): Bind config to a Runnable, returning a new Runnable. """ return RunnableBinding( - bound=self, config={**(config or {}), **kwargs}, kwargs={} + bound=self, + config=cast( + RunnableConfig, + {**(config or {}), **kwargs}, + ), # type: ignore[misc] + kwargs={}, ) def with_retry( @@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): kwargs: Mapping[str, Any] - config: Mapping[str, Any] = Field(default_factory=dict) + config: RunnableConfig = Field(default_factory=dict) class Config: arbitrary_types_allowed = True @@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): *, bound: Runnable[Input, Output], kwargs: Mapping[str, Any], - config: Optional[Mapping[str, Any]] = None, + config: Optional[RunnableConfig] = None, **other_kwargs: Any, ) -> None: config = config or {} @@ -2346,22 +2352,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig: - copy = cast(RunnableConfig, dict(self.config)) - if config: - for key in config: - if key == "metadata": - copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore - elif key == "tags": - copy[key] = (copy.get(key) or []) + config[key] # type: ignore - elif key == "configurable": - copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore - else: - # Even though the keys aren't literals this is correct - # because both dicts are same type - copy[key] = config[key] or copy.get(key) # type: ignore - return copy - def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} @@ -2376,7 +2366,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): return self.__class__( bound=self.bound, kwargs=self.kwargs, - config={**self.config, **(config or {}), **kwargs}, + config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), ) def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: @@ -2394,7 +2384,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Output: return self.bound.invoke( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2406,7 +2396,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Output: return await self.bound.ainvoke( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2420,11 +2410,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> List[Output]: if isinstance(config, list): configs = cast( - List[RunnableConfig], [self._merge_config(conf) for conf in config] + List[RunnableConfig], + [merge_configs(self.config, conf) for conf in config], ) else: configs = [ - patch_config(self._merge_config(config), copy_locals=True) + patch_config(merge_configs(self.config, config), copy_locals=True) for _ in range(len(inputs)) ] return self.bound.batch( @@ -2444,11 +2435,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> List[Output]: if isinstance(config, list): configs = cast( - List[RunnableConfig], [self._merge_config(conf) for conf in config] + List[RunnableConfig], + [merge_configs(self.config, conf) for conf in config], ) else: configs = [ - patch_config(self._merge_config(config), copy_locals=True) + patch_config(merge_configs(self.config, config), copy_locals=True) for _ in range(len(inputs)) ] return await self.bound.abatch( @@ -2466,7 +2458,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.stream( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2478,7 +2470,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.astream( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ): yield item @@ -2491,7 +2483,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.transform( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ) @@ -2503,7 +2495,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.atransform( input, - self._merge_config(config), + merge_configs(self.config, config), **{**self.kwargs, **kwargs}, ): yield item diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 71eb7428e06..44f99b6510f 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -157,6 +157,23 @@ def patch_config( return config +def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: + base: RunnableConfig = {} + for config in (c for c in configs if c is not None): + for key in config: + if key == "metadata": + base[key] = {**base.get(key, {}), **config[key]} # type: ignore + elif key == "tags": + base[key] = list(set(base.get(key, []) + config[key])) # type: ignore + elif key == "configurable": + base[key] = {**base.get(key, {}), **config[key]} # type: ignore + else: + # Even though the keys aren't literals this is correct + # because both dicts are same type + base[key] = config[key] or base.get(key) # type: ignore + return base + + def call_func_with_variable_args( func: Union[ Callable[[Input], Output],