mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-28 14:05:02 +00:00
Adjust merge logic
This commit is contained in:
parent
f95bd0bcd9
commit
5426712311
@ -1527,6 +1527,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
def lc_namespace(self) -> List[str]:
|
def lc_namespace(self) -> List[str]:
|
||||||
return self.__class__.__module__.split(".")[:-1]
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
|
copy = cast(RunnableConfig, dict(self.config))
|
||||||
|
if config:
|
||||||
|
for key in config:
|
||||||
|
copy[key] = config[key] or copy.get(key)
|
||||||
|
return copy
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||||
@ -1552,7 +1559,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
return self.bound.invoke(
|
return self.bound.invoke(
|
||||||
input,
|
input,
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
self._merge_config(config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1564,7 +1571,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
return await self.bound.ainvoke(
|
return await self.bound.ainvoke(
|
||||||
input,
|
input,
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
self._merge_config(config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1576,13 +1583,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig],
|
List[RunnableConfig],
|
||||||
[{**self.config, **(conf or {})} for conf in config]
|
[self._merge_config(conf) for conf in config]
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [
|
else [
|
||||||
patch_config(
|
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
|
||||||
deep_copy_locals=True,
|
|
||||||
)
|
|
||||||
for _ in range(len(inputs))
|
for _ in range(len(inputs))
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1596,13 +1600,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
configs = cast(
|
configs = cast(
|
||||||
List[RunnableConfig],
|
List[RunnableConfig],
|
||||||
[{**self.config, **(conf or {})} for conf in config]
|
[self._merge_config(conf) for conf in config]
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [
|
else [
|
||||||
patch_config(
|
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
|
||||||
deep_copy_locals=True,
|
|
||||||
)
|
|
||||||
for _ in range(len(inputs))
|
for _ in range(len(inputs))
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1616,7 +1617,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.stream(
|
yield from self.bound.stream(
|
||||||
input,
|
input,
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
self._merge_config(config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1628,7 +1629,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.astream(
|
async for item in self.bound.astream(
|
||||||
input,
|
input,
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
self._merge_config(config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
@ -1641,7 +1642,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.transform(
|
yield from self.bound.transform(
|
||||||
input,
|
input,
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
self._merge_config(config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1653,7 +1654,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.atransform(
|
async for item in self.bound.atransform(
|
||||||
input,
|
input,
|
||||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
self._merge_config(config),
|
||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
@ -123,6 +124,25 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
]
|
]
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
|
fake_1 = RunnablePassthrough()
|
||||||
|
fake_2 = RunnablePassthrough()
|
||||||
|
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
|
||||||
|
|
||||||
|
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
|
||||||
|
tags=["b-tag"], max_concurrency=5
|
||||||
|
)
|
||||||
|
assert sequence.invoke("hello") == "hello"
|
||||||
|
assert len(spy_seq_step.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy_seq_step.call_args_list):
|
||||||
|
assert call.args[1] == "hello"
|
||||||
|
if i == 0:
|
||||||
|
assert call.args[2].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[2].get("max_concurrency") is None
|
||||||
|
else:
|
||||||
|
assert call.args[2].get("tags") == ["b-tag"]
|
||||||
|
assert call.args[2].get("max_concurrency") == 5
|
||||||
|
spy_seq_step.reset_mock()
|
||||||
|
|
||||||
assert [
|
assert [
|
||||||
*fake.with_config(tags=["a-tag"]).stream(
|
*fake.with_config(tags=["a-tag"]).stream(
|
||||||
"hello", dict(metadata={"key": "value"})
|
"hello", dict(metadata={"key": "value"})
|
||||||
@ -161,14 +181,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
assert call.args[1].get("metadata") == {"a": "b"}
|
assert call.args[1].get("metadata") == {"a": "b"}
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
|
handler = ConsoleCallbackHandler()
|
||||||
assert (
|
assert (
|
||||||
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
||||||
"hello", config={"callbacks": []}
|
"hello", config={"callbacks": [handler]}
|
||||||
)
|
)
|
||||||
== 5
|
== 5
|
||||||
)
|
)
|
||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
|
mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})),
|
||||||
]
|
]
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user