Export merge_configs function (#11916)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-10-17 15:36:11 +01:00 committed by GitHub
commit 2a8ded6c8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 30 deletions

View File

@ -54,6 +54,7 @@ from langchain.schema.runnable.config import (
get_callback_manager_for_config,
get_config_list,
get_executor_for_config,
merge_configs,
patch_config,
)
from langchain.schema.runnable.utils import (
@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC):
Bind config to a Runnable, returning a new Runnable.
"""
return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={}
bound=self,
config=cast(
RunnableConfig,
{**(config or {}), **kwargs},
), # type: ignore[misc]
kwargs={},
)
def with_retry(
@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
kwargs: Mapping[str, Any]
config: Mapping[str, Any] = Field(default_factory=dict)
config: RunnableConfig = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
config: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None,
**other_kwargs: Any,
) -> None:
config = config or {}
@ -2347,22 +2353,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
if key == "metadata":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
elif key == "configurable":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
@ -2377,7 +2367,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs},
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
)
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
@ -2395,7 +2385,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output:
return self.bound.invoke(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2407,7 +2397,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Output:
return await self.bound.ainvoke(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2421,11 +2411,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> List[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
)
else:
configs = [
patch_config(self._merge_config(config), copy_locals=True)
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(
@ -2445,11 +2436,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> List[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
)
else:
configs = [
patch_config(self._merge_config(config), copy_locals=True)
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(
@ -2467,7 +2459,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.stream(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2479,7 +2471,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
):
yield item
@ -2492,7 +2484,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.transform(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)
@ -2504,7 +2496,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
):
yield item

View File

@ -157,6 +157,31 @@ def patch_config(
return config
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base: RunnableConfig = {}
# Even though the keys aren't literals this is correct
# because both dicts are same type
for config in (c for c in configs if c is not None):
for key in config:
if key == "metadata":
base[key] = { # type: ignore
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
elif key == "tags":
base[key] = list( # type: ignore
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
)
elif key == "configurable":
base[key] = { # type: ignore
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
else:
base[key] = config[key] or base.get(key) # type: ignore
return base
def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],