mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
core[patch]:Fix Incorrect listeners parameters for Runnable.with_listeners() and .map() (#20661)
- **Issue:** fix #20509 - @baskaryan, @eyurtsev 
This commit is contained in:
parent
b0f5a47f25
commit
a156aace2b
@ -4266,9 +4266,10 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
return self.bound.batch(
|
configs = [
|
||||||
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
|
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
|
||||||
)
|
]
|
||||||
|
return self.bound.batch(inputs, configs, **kwargs)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
@ -4282,9 +4283,10 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
return await self.bound.abatch(
|
configs = [
|
||||||
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
|
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
|
||||||
)
|
]
|
||||||
|
return await self.bound.abatch(inputs, configs, **kwargs)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
@ -5508,3 +5508,64 @@ async def test_passthrough_atransform_with_dicts() -> None:
|
|||||||
|
|
||||||
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
||||||
assert chunks == [{"foo": "a"}, {"foo": "n"}]
|
assert chunks == [{"foo": "a"}, {"foo": "n"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_listeners() -> None:
|
||||||
|
from langchain_core.runnables import RunnableLambda
|
||||||
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
|
def fake_chain(inputs: dict) -> dict:
|
||||||
|
return {**inputs, "key": "extra"}
|
||||||
|
|
||||||
|
shared_state = {}
|
||||||
|
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||||
|
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||||
|
|
||||||
|
def on_start(run: Run) -> None:
|
||||||
|
shared_state[run.id] = {"inputs": run.inputs}
|
||||||
|
|
||||||
|
def on_end(run: Run) -> None:
|
||||||
|
shared_state[run.id]["outputs"] = run.inputs
|
||||||
|
|
||||||
|
chain = (
|
||||||
|
RunnableLambda(fake_chain)
|
||||||
|
.with_listeners(on_end=on_end, on_start=on_start)
|
||||||
|
.map()
|
||||||
|
)
|
||||||
|
|
||||||
|
data = [{"name": "one"}, {"name": "two"}]
|
||||||
|
chain.invoke(data, config={"max_concurrency": 1})
|
||||||
|
assert len(shared_state) == 2
|
||||||
|
assert value1 in shared_state.values(), "Value not found in the dictionary."
|
||||||
|
assert value2 in shared_state.values(), "Value not found in the dictionary."
|
||||||
|
|
||||||
|
|
||||||
|
async def test_listeners_async() -> None:
|
||||||
|
from langchain_core.runnables import RunnableLambda
|
||||||
|
from langchain_core.tracers.schemas import Run
|
||||||
|
|
||||||
|
def fake_chain(inputs: dict) -> dict:
|
||||||
|
return {**inputs, "key": "extra"}
|
||||||
|
|
||||||
|
shared_state = {}
|
||||||
|
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||||
|
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||||
|
|
||||||
|
def on_start(run: Run) -> None:
|
||||||
|
shared_state[run.id] = {"inputs": run.inputs}
|
||||||
|
|
||||||
|
def on_end(run: Run) -> None:
|
||||||
|
shared_state[run.id]["outputs"] = run.inputs
|
||||||
|
|
||||||
|
chain: Runnable = (
|
||||||
|
RunnableLambda(fake_chain)
|
||||||
|
.with_listeners(on_end=on_end, on_start=on_start)
|
||||||
|
.map()
|
||||||
|
)
|
||||||
|
|
||||||
|
data = [{"name": "one"}, {"name": "two"}]
|
||||||
|
await chain.ainvoke(data, config={"max_concurrency": 1})
|
||||||
|
|
||||||
|
assert len(shared_state) == 2
|
||||||
|
assert value1 in shared_state.values(), "Value not found in the dictionary."
|
||||||
|
assert value2 in shared_state.values(), "Value not found in the dictionary."
|
||||||
|
Loading…
Reference in New Issue
Block a user