mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
core[patch]:Fix Incorrect listeners parameters for Runnable.with_listeners() and .map() (#20661)
- **Issue:** fix #20509 - @baskaryan, @eyurtsev 
This commit is contained in:
parent
b0f5a47f25
commit
a156aace2b
@ -4266,9 +4266,10 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> List[Output]:
|
||||
return self.bound.batch(
|
||||
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
|
||||
)
|
||||
configs = [
|
||||
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
|
||||
]
|
||||
return self.bound.batch(inputs, configs, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
@ -4282,9 +4283,10 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> List[Output]:
|
||||
return await self.bound.abatch(
|
||||
inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
|
||||
)
|
||||
configs = [
|
||||
patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
|
||||
]
|
||||
return await self.bound.abatch(inputs, configs, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
|
@ -5508,3 +5508,64 @@ async def test_passthrough_atransform_with_dicts() -> None:
|
||||
|
||||
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
||||
assert chunks == [{"foo": "a"}, {"foo": "n"}]
|
||||
|
||||
|
||||
def test_listeners() -> None:
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
def fake_chain(inputs: dict) -> dict:
|
||||
return {**inputs, "key": "extra"}
|
||||
|
||||
shared_state = {}
|
||||
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
|
||||
def on_start(run: Run) -> None:
|
||||
shared_state[run.id] = {"inputs": run.inputs}
|
||||
|
||||
def on_end(run: Run) -> None:
|
||||
shared_state[run.id]["outputs"] = run.inputs
|
||||
|
||||
chain = (
|
||||
RunnableLambda(fake_chain)
|
||||
.with_listeners(on_end=on_end, on_start=on_start)
|
||||
.map()
|
||||
)
|
||||
|
||||
data = [{"name": "one"}, {"name": "two"}]
|
||||
chain.invoke(data, config={"max_concurrency": 1})
|
||||
assert len(shared_state) == 2
|
||||
assert value1 in shared_state.values(), "Value not found in the dictionary."
|
||||
assert value2 in shared_state.values(), "Value not found in the dictionary."
|
||||
|
||||
|
||||
async def test_listeners_async() -> None:
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
def fake_chain(inputs: dict) -> dict:
|
||||
return {**inputs, "key": "extra"}
|
||||
|
||||
shared_state = {}
|
||||
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
|
||||
def on_start(run: Run) -> None:
|
||||
shared_state[run.id] = {"inputs": run.inputs}
|
||||
|
||||
def on_end(run: Run) -> None:
|
||||
shared_state[run.id]["outputs"] = run.inputs
|
||||
|
||||
chain: Runnable = (
|
||||
RunnableLambda(fake_chain)
|
||||
.with_listeners(on_end=on_end, on_start=on_start)
|
||||
.map()
|
||||
)
|
||||
|
||||
data = [{"name": "one"}, {"name": "two"}]
|
||||
await chain.ainvoke(data, config={"max_concurrency": 1})
|
||||
|
||||
assert len(shared_state) == 2
|
||||
assert value1 in shared_state.values(), "Value not found in the dictionary."
|
||||
assert value2 in shared_state.values(), "Value not found in the dictionary."
|
||||
|
Loading…
Reference in New Issue
Block a user