Export merge_configs function

This commit is contained in:
Nuno Campos 2023-10-17 13:22:07 +01:00
parent 12596b9a9b
commit b0d5882fe1
2 changed files with 39 additions and 30 deletions

View File

@ -54,6 +54,7 @@ from langchain.schema.runnable.config import (
get_callback_manager_for_config,
get_config_list,
get_executor_for_config,
merge_configs,
patch_config,
)
from langchain.schema.runnable.utils import (
@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC):
Bind config to a Runnable, returning a new Runnable.
"""
return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={}
bound=self,
config=cast(
RunnableConfig,
{**(config or {}), **kwargs},
), # type: ignore[misc]
kwargs={},
)
def with_retry(
@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
kwargs: Mapping[str, Any]
config: Mapping[str, Any] = Field(default_factory=dict)
config: RunnableConfig = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
config: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None,
**other_kwargs: Any,
) -> None:
config = config or {}
@ -2346,22 +2352,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
if key == "metadata":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
elif key == "configurable":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
@ -2376,7 +2366,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs},
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
)
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
@ -2394,7 +2384,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output:
return self.bound.invoke(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2406,7 +2396,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output:
return await self.bound.ainvoke(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2420,11 +2410,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> List[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
)
else:
configs = [
patch_config(self._merge_config(config), copy_locals=True)
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(
@ -2444,11 +2435,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> List[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
)
else:
configs = [
patch_config(self._merge_config(config), copy_locals=True)
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(
@ -2466,7 +2458,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.stream(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2478,7 +2470,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
):
yield item
@ -2491,7 +2483,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.transform(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2503,7 +2495,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
):
yield item

View File

@ -157,6 +157,23 @@ def patch_config(
return config
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base: RunnableConfig = {}
for config in (c for c in configs if c is not None):
for key in config:
if key == "metadata":
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
base[key] = list(set(base.get(key, []) + config[key])) # type: ignore
elif key == "configurable":
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
base[key] = config[key] or base.get(key) # type: ignore
return base
def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],