mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 10:54:15 +00:00
Export merge_configs function
This commit is contained in:
parent
12596b9a9b
commit
b0d5882fe1
@ -54,6 +54,7 @@ from langchain.schema.runnable.config import (
|
||||
get_callback_manager_for_config,
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
merge_configs,
|
||||
patch_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Bind config to a Runnable, returning a new Runnable.
|
||||
"""
|
||||
return RunnableBinding(
|
||||
bound=self, config={**(config or {}), **kwargs}, kwargs={}
|
||||
bound=self,
|
||||
config=cast(
|
||||
RunnableConfig,
|
||||
{**(config or {}), **kwargs},
|
||||
), # type: ignore[misc]
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
def with_retry(
|
||||
@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
|
||||
kwargs: Mapping[str, Any]
|
||||
|
||||
config: Mapping[str, Any] = Field(default_factory=dict)
|
||||
config: RunnableConfig = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
*,
|
||||
bound: Runnable[Input, Output],
|
||||
kwargs: Mapping[str, Any],
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
config = config or {}
|
||||
@ -2346,22 +2352,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
copy = cast(RunnableConfig, dict(self.config))
|
||||
if config:
|
||||
for key in config:
|
||||
if key == "metadata":
|
||||
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
|
||||
elif key == "tags":
|
||||
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
|
||||
elif key == "configurable":
|
||||
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
|
||||
else:
|
||||
# Even though the keys aren't literals this is correct
|
||||
# because both dicts are same type
|
||||
copy[key] = config[key] or copy.get(key) # type: ignore
|
||||
return copy
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||
@ -2376,7 +2366,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config={**self.config, **(config or {}), **kwargs},
|
||||
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
|
||||
)
|
||||
|
||||
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
@ -2394,7 +2384,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Output:
|
||||
return self.bound.invoke(
|
||||
input,
|
||||
self._merge_config(config),
|
||||
merge_configs(self.config, config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2406,7 +2396,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Output:
|
||||
return await self.bound.ainvoke(
|
||||
input,
|
||||
self._merge_config(config),
|
||||
merge_configs(self.config, config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2420,11 +2410,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig], [self._merge_config(conf) for conf in config]
|
||||
List[RunnableConfig],
|
||||
[merge_configs(self.config, conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
configs = [
|
||||
patch_config(self._merge_config(config), copy_locals=True)
|
||||
patch_config(merge_configs(self.config, config), copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
return self.bound.batch(
|
||||
@ -2444,11 +2435,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
if isinstance(config, list):
|
||||
configs = cast(
|
||||
List[RunnableConfig], [self._merge_config(conf) for conf in config]
|
||||
List[RunnableConfig],
|
||||
[merge_configs(self.config, conf) for conf in config],
|
||||
)
|
||||
else:
|
||||
configs = [
|
||||
patch_config(self._merge_config(config), copy_locals=True)
|
||||
patch_config(merge_configs(self.config, config), copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
return await self.bound.abatch(
|
||||
@ -2466,7 +2458,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.stream(
|
||||
input,
|
||||
self._merge_config(config),
|
||||
merge_configs(self.config, config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2478,7 +2470,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.astream(
|
||||
input,
|
||||
self._merge_config(config),
|
||||
merge_configs(self.config, config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
@ -2491,7 +2483,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.transform(
|
||||
input,
|
||||
self._merge_config(config),
|
||||
merge_configs(self.config, config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -2503,7 +2495,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.atransform(
|
||||
input,
|
||||
self._merge_config(config),
|
||||
merge_configs(self.config, config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
|
@ -157,6 +157,23 @@ def patch_config(
|
||||
return config
|
||||
|
||||
|
||||
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
base: RunnableConfig = {}
|
||||
for config in (c for c in configs if c is not None):
|
||||
for key in config:
|
||||
if key == "metadata":
|
||||
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
|
||||
elif key == "tags":
|
||||
base[key] = list(set(base.get(key, []) + config[key])) # type: ignore
|
||||
elif key == "configurable":
|
||||
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
|
||||
else:
|
||||
# Even though the keys aren't literals this is correct
|
||||
# because both dicts are same type
|
||||
base[key] = config[key] or base.get(key) # type: ignore
|
||||
return base
|
||||
|
||||
|
||||
def call_func_with_variable_args(
|
||||
func: Union[
|
||||
Callable[[Input], Output],
|
||||
|
Loading…
Reference in New Issue
Block a user