Adjust merge logic

This commit is contained in:
Nuno Campos 2023-08-24 16:49:14 +02:00
parent f95bd0bcd9
commit 5426712311
2 changed files with 40 additions and 18 deletions

View File

@ -1527,6 +1527,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_namespace(self) -> List[str]: def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1] return self.__class__.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
copy[key] = config[key] or copy.get(key)
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__( return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
@ -1552,7 +1559,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Output: ) -> Output:
return self.bound.invoke( return self.bound.invoke(
input, input,
cast(RunnableConfig, {**self.config, **(config or {})}), self._merge_config(config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -1564,7 +1571,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Output: ) -> Output:
return await self.bound.ainvoke( return await self.bound.ainvoke(
input, input,
cast(RunnableConfig, {**self.config, **(config or {})}), self._merge_config(config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -1576,13 +1583,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> List[Output]: ) -> List[Output]:
configs = cast( configs = cast(
List[RunnableConfig], List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config] [self._merge_config(conf) for conf in config]
if isinstance(config, list) if isinstance(config, list)
else [ else [
patch_config( patch_config(self._merge_config(config), deep_copy_locals=True)
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
for _ in range(len(inputs)) for _ in range(len(inputs))
], ],
) )
@ -1596,13 +1600,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> List[Output]: ) -> List[Output]:
configs = cast( configs = cast(
List[RunnableConfig], List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config] [self._merge_config(conf) for conf in config]
if isinstance(config, list) if isinstance(config, list)
else [ else [
patch_config( patch_config(self._merge_config(config), deep_copy_locals=True)
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
for _ in range(len(inputs)) for _ in range(len(inputs))
], ],
) )
@ -1616,7 +1617,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.stream( yield from self.bound.stream(
input, input,
cast(RunnableConfig, {**self.config, **(config or {})}), self._merge_config(config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -1628,7 +1629,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.astream( async for item in self.bound.astream(
input, input,
cast(RunnableConfig, {**self.config, **(config or {})}), self._merge_config(config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
): ):
yield item yield item
@ -1641,7 +1642,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.transform( yield from self.bound.transform(
input, input,
cast(RunnableConfig, {**self.config, **(config or {})}), self._merge_config(config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
) )
@ -1653,7 +1654,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.atransform( async for item in self.bound.atransform(
input, input,
cast(RunnableConfig, {**self.config, **(config or {})}), self._merge_config(config),
**{**self.kwargs, **kwargs}, **{**self.kwargs, **kwargs},
): ):
yield item yield item

View File

@ -1,6 +1,7 @@
from operator import itemgetter from operator import itemgetter
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from uuid import UUID from uuid import UUID
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
@ -123,6 +124,25 @@ async def test_with_config(mocker: MockerFixture) -> None:
] ]
spy.reset_mock() spy.reset_mock()
fake_1 = RunnablePassthrough()
fake_2 = RunnablePassthrough()
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
tags=["b-tag"], max_concurrency=5
)
assert sequence.invoke("hello") == "hello"
assert len(spy_seq_step.call_args_list) == 2
for i, call in enumerate(spy_seq_step.call_args_list):
assert call.args[1] == "hello"
if i == 0:
assert call.args[2].get("tags") == ["a-tag"]
assert call.args[2].get("max_concurrency") is None
else:
assert call.args[2].get("tags") == ["b-tag"]
assert call.args[2].get("max_concurrency") == 5
spy_seq_step.reset_mock()
assert [ assert [
*fake.with_config(tags=["a-tag"]).stream( *fake.with_config(tags=["a-tag"]).stream(
"hello", dict(metadata={"key": "value"}) "hello", dict(metadata={"key": "value"})
@ -161,14 +181,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert call.args[1].get("metadata") == {"a": "b"} assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock() spy.reset_mock()
handler = ConsoleCallbackHandler()
assert ( assert (
await fake.with_config(metadata={"a": "b"}).ainvoke( await fake.with_config(metadata={"a": "b"}).ainvoke(
"hello", config={"callbacks": []} "hello", config={"callbacks": [handler]}
) )
== 5 == 5
) )
assert spy.call_args_list == [ assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})), mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})),
] ]
spy.reset_mock() spy.reset_mock()