mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-02 03:15:11 +00:00
Export merge_configs function
This commit is contained in:
parent
12596b9a9b
commit
b0d5882fe1
@ -54,6 +54,7 @@ from langchain.schema.runnable.config import (
|
|||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
get_config_list,
|
get_config_list,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
|
merge_configs,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.utils import (
|
from langchain.schema.runnable.utils import (
|
||||||
@ -564,7 +565,12 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Bind config to a Runnable, returning a new Runnable.
|
Bind config to a Runnable, returning a new Runnable.
|
||||||
"""
|
"""
|
||||||
return RunnableBinding(
|
return RunnableBinding(
|
||||||
bound=self, config={**(config or {}), **kwargs}, kwargs={}
|
bound=self,
|
||||||
|
config=cast(
|
||||||
|
RunnableConfig,
|
||||||
|
{**(config or {}), **kwargs},
|
||||||
|
), # type: ignore[misc]
|
||||||
|
kwargs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
def with_retry(
|
def with_retry(
|
||||||
@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
|
|
||||||
kwargs: Mapping[str, Any]
|
kwargs: Mapping[str, Any]
|
||||||
|
|
||||||
config: Mapping[str, Any] = Field(default_factory=dict)
|
config: RunnableConfig = Field(default_factory=dict)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -2301,7 +2307,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
*,
|
*,
|
||||||
bound: Runnable[Input, Output],
|
bound: Runnable[Input, Output],
|
||||||
kwargs: Mapping[str, Any],
|
kwargs: Mapping[str, Any],
|
||||||
config: Optional[Mapping[str, Any]] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**other_kwargs: Any,
|
**other_kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
@ -2346,22 +2352,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
return cls.__module__.split(".")[:-1]
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
|
|
||||||
copy = cast(RunnableConfig, dict(self.config))
|
|
||||||
if config:
|
|
||||||
for key in config:
|
|
||||||
if key == "metadata":
|
|
||||||
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
|
|
||||||
elif key == "tags":
|
|
||||||
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
|
|
||||||
elif key == "configurable":
|
|
||||||
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
|
|
||||||
else:
|
|
||||||
# Even though the keys aren't literals this is correct
|
|
||||||
# because both dicts are same type
|
|
||||||
copy[key] = config[key] or copy.get(key) # type: ignore
|
|
||||||
return copy
|
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||||
@ -2376,7 +2366,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
return self.__class__(
|
return self.__class__(
|
||||||
bound=self.bound,
|
bound=self.bound,
|
||||||
kwargs=self.kwargs,
|
kwargs=self.kwargs,
|
||||||
config={**self.config, **(config or {}), **kwargs},
|
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
|
||||||
)
|
)
|
||||||
|
|
||||||
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
@ -2394,7 +2384,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
return self.bound.invoke(
|
return self.bound.invoke(
|
||||||
input,
|
input,
|
||||||
self._merge_config(config),
|
merge_configs(self.config, config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2406,7 +2396,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
return await self.bound.ainvoke(
|
return await self.bound.ainvoke(
|
||||||
input,
|
input,
|
||||||
self._merge_config(config),
|
merge_configs(self.config, config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2420,11 +2410,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
if isinstance(config, list):
|
if isinstance(config, list):
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig], [self._merge_config(conf) for conf in config]
|
List[RunnableConfig],
|
||||||
|
[merge_configs(self.config, conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs = [
|
configs = [
|
||||||
patch_config(self._merge_config(config), copy_locals=True)
|
patch_config(merge_configs(self.config, config), copy_locals=True)
|
||||||
for _ in range(len(inputs))
|
for _ in range(len(inputs))
|
||||||
]
|
]
|
||||||
return self.bound.batch(
|
return self.bound.batch(
|
||||||
@ -2444,11 +2435,12 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
if isinstance(config, list):
|
if isinstance(config, list):
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig], [self._merge_config(conf) for conf in config]
|
List[RunnableConfig],
|
||||||
|
[merge_configs(self.config, conf) for conf in config],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs = [
|
configs = [
|
||||||
patch_config(self._merge_config(config), copy_locals=True)
|
patch_config(merge_configs(self.config, config), copy_locals=True)
|
||||||
for _ in range(len(inputs))
|
for _ in range(len(inputs))
|
||||||
]
|
]
|
||||||
return await self.bound.abatch(
|
return await self.bound.abatch(
|
||||||
@ -2466,7 +2458,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.stream(
|
yield from self.bound.stream(
|
||||||
input,
|
input,
|
||||||
self._merge_config(config),
|
merge_configs(self.config, config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2478,7 +2470,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.astream(
|
async for item in self.bound.astream(
|
||||||
input,
|
input,
|
||||||
self._merge_config(config),
|
merge_configs(self.config, config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
@ -2491,7 +2483,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.transform(
|
yield from self.bound.transform(
|
||||||
input,
|
input,
|
||||||
self._merge_config(config),
|
merge_configs(self.config, config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2503,7 +2495,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.atransform(
|
async for item in self.bound.atransform(
|
||||||
input,
|
input,
|
||||||
self._merge_config(config),
|
merge_configs(self.config, config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
|
@ -157,6 +157,23 @@ def patch_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
|
base: RunnableConfig = {}
|
||||||
|
for config in (c for c in configs if c is not None):
|
||||||
|
for key in config:
|
||||||
|
if key == "metadata":
|
||||||
|
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
|
||||||
|
elif key == "tags":
|
||||||
|
base[key] = list(set(base.get(key, []) + config[key])) # type: ignore
|
||||||
|
elif key == "configurable":
|
||||||
|
base[key] = {**base.get(key, {}), **config[key]} # type: ignore
|
||||||
|
else:
|
||||||
|
# Even though the keys aren't literals this is correct
|
||||||
|
# because both dicts are same type
|
||||||
|
base[key] = config[key] or base.get(key) # type: ignore
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
def call_func_with_variable_args(
|
def call_func_with_variable_args(
|
||||||
func: Union[
|
func: Union[
|
||||||
Callable[[Input], Output],
|
Callable[[Input], Output],
|
||||||
|
Loading…
Reference in New Issue
Block a user