Export merge_configs function

This commit is contained in:
Nuno Campos 2023-10-17 13:22:07 +01:00
parent 12596b9a9b
commit b0d5882fe1
2 changed files with 39 additions and 30 deletions

View File

@ -54,6 +54,7 @@ from langchain.schema.runnable.config import (
get_callback_manager_for_config, get_callback_manager_for_config,
get_config_list, get_config_list,
get_executor_for_config, get_executor_for_config,
merge_configs,
patch_config, patch_config,
) )
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC):
Bind config to a Runnable, returning a new Runnable. Bind config to a Runnable, returning a new Runnable.
""" """
return RunnableBinding( return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={} bound=self,
config=cast(
RunnableConfig,
{**(config or {}), **kwargs},
), # type: ignore[misc]
kwargs={},
) )
def with_retry( def with_retry(
@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
kwargs: Mapping[str, Any] kwargs: Mapping[str, Any]
config: Mapping[str, Any] = Field(default_factory=dict) config: RunnableConfig = Field(default_factory=dict)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
*, *,
bound: Runnable[Input, Output], bound: Runnable[Input, Output],
kwargs: Mapping[str, Any], kwargs: Mapping[str, Any],
config: Optional[Mapping[str, Any]] = None, config: Optional[RunnableConfig] = None,
**other_kwargs: Any, **other_kwargs: Any,
) -> None: ) -> None:
config = config or {} config = config or {}
@ -2346,22 +2352,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] return cls.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
if key == "metadata":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
elif key == "configurable":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__( return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
@ -2376,7 +2366,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
return self.__class__( return self.__class__(
bound=self.bound, bound=self.bound,
kwargs=self.kwargs, kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs}, config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
) )
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
@ -2394,7 +2384,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output: ) -> Output:
return self.bound.invoke( return self.bound.invoke(
input, input,
self._merge_config(config), merge_configs(self.config, config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -2406,7 +2396,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output: ) -> Output:
return await self.bound.ainvoke( return await self.bound.ainvoke(
input, input,
self._merge_config(config), merge_configs(self.config, config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -2420,11 +2410,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> List[Output]: ) -> List[Output]:
if isinstance(config, list): if isinstance(config, list):
configs = cast( configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config] List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
) )
else: else:
configs = [ configs = [
patch_config(self._merge_config(config), copy_locals=True) patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs)) for _ in range(len(inputs))
] ]
return self.bound.batch( return self.bound.batch(
@ -2444,11 +2435,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> List[Output]: ) -> List[Output]:
if isinstance(config, list): if isinstance(config, list):
configs = cast( configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config] List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
) )
else: else:
configs = [ configs = [
patch_config(self._merge_config(config), copy_locals=True) patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs)) for _ in range(len(inputs))
] ]
return await self.bound.abatch( return await self.bound.abatch(
@ -2466,7 +2458,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.stream( yield from self.bound.stream(
input, input,
self._merge_config(config), merge_configs(self.config, config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -2478,7 +2470,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.astream( async for item in self.bound.astream(
input, input,
self._merge_config(config), merge_configs(self.config, config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
): ):
yield item yield item
@ -2491,7 +2483,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.transform( yield from self.bound.transform(
input, input,
self._merge_config(config), merge_configs(self.config, config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -2503,7 +2495,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.atransform( async for item in self.bound.atransform(
input, input,
self._merge_config(config), merge_configs(self.config, config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
): ):
yield item yield item

View File

@ -157,6 +157,23 @@ def patch_config(
return config return config
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base: RunnableConfig = {}
for config in (c for c in configs if c is not None):
for key in config:
if key == "metadata":
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
base[key] = list(set(base.get(key, []) + config[key])) # type: ignore
elif key == "configurable":
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
base[key] = config[key] or base.get(key) # type: ignore
return base
def call_func_with_variable_args( def call_func_with_variable_args(
func: Union[ func: Union[
Callable[[Input], Output], Callable[[Input], Output],