core[patch]: Fix incorrect listener parameters for Runnable.with_listeners() and .map() (#20661)

- **Issue:** Fixes #20509
- cc @baskaryan, @eyurtsev


![image](https://github.com/langchain-ai/langchain/assets/48236177/f799a976-b983-4d8b-b373-64392e1fd6c6)
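For context, here is a minimal standalone sketch (not part of the diff) of the behavior this patch targets: listeners registered via `.with_listeners()` should fire once per input element when the runnable is mapped over a batch with `.map()`. The helper names (`add_key`, `seen`) are illustrative only; the API calls mirror the test added below.

```python
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run


def add_key(inputs: dict) -> dict:
    # Toy step standing in for a real chain component.
    return {**inputs, "key": "extra"}


seen: dict = {}


def on_start(run: Run) -> None:
    # With the fix, this fires once per element, each with its own Run id.
    seen[run.id] = {"inputs": run.inputs}


def on_end(run: Run) -> None:
    seen[run.id]["done"] = True


chain = (
    RunnableLambda(add_key)
    .with_listeners(on_start=on_start, on_end=on_end)
    .map()
)

chain.invoke([{"name": "one"}, {"name": "two"}])
assert len(seen) == 2  # one listener invocation per input element
```

Before this change, `RunnableEachBase` passed a single patched config to `self.bound.batch()`/`abatch()`; the diff below builds one config per input instead, so each element gets its own child callbacks.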
Guangdong Liu 2024-05-13 23:16:17 +08:00 committed by GitHub
parent b0f5a47f25
commit a156aace2b
2 changed files with 69 additions and 6 deletions


@@ -4266,9 +4266,10 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
         config: RunnableConfig,
         **kwargs: Any,
     ) -> List[Output]:
-        return self.bound.batch(
-            inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
-        )
+        configs = [
+            patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
+        ]
+        return self.bound.batch(inputs, configs, **kwargs)
 
     def invoke(
         self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -4282,9 +4283,10 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
         config: RunnableConfig,
         **kwargs: Any,
     ) -> List[Output]:
-        return await self.bound.abatch(
-            inputs, patch_config(config, callbacks=run_manager.get_child()), **kwargs
-        )
+        configs = [
+            patch_config(config, callbacks=run_manager.get_child()) for _ in inputs
+        ]
+        return await self.bound.abatch(inputs, configs, **kwargs)
 
     async def ainvoke(
         self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any


@@ -5508,3 +5508,64 @@ async def test_passthrough_atransform_with_dicts() -> None:
 
     chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
     assert chunks == [{"foo": "a"}, {"foo": "n"}]
+
+
+def test_listeners() -> None:
+    from langchain_core.runnables import RunnableLambda
+    from langchain_core.tracers.schemas import Run
+
+    def fake_chain(inputs: dict) -> dict:
+        return {**inputs, "key": "extra"}
+
+    shared_state = {}
+    value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
+    value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
+
+    def on_start(run: Run) -> None:
+        shared_state[run.id] = {"inputs": run.inputs}
+
+    def on_end(run: Run) -> None:
+        shared_state[run.id]["outputs"] = run.inputs
+
+    chain = (
+        RunnableLambda(fake_chain)
+        .with_listeners(on_end=on_end, on_start=on_start)
+        .map()
+    )
+
+    data = [{"name": "one"}, {"name": "two"}]
+    chain.invoke(data, config={"max_concurrency": 1})
+    assert len(shared_state) == 2
+    assert value1 in shared_state.values(), "Value not found in the dictionary."
+    assert value2 in shared_state.values(), "Value not found in the dictionary."
+
+
+async def test_listeners_async() -> None:
+    from langchain_core.runnables import RunnableLambda
+    from langchain_core.tracers.schemas import Run
+
+    def fake_chain(inputs: dict) -> dict:
+        return {**inputs, "key": "extra"}
+
+    shared_state = {}
+    value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
+    value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
+
+    def on_start(run: Run) -> None:
+        shared_state[run.id] = {"inputs": run.inputs}
+
+    def on_end(run: Run) -> None:
+        shared_state[run.id]["outputs"] = run.inputs
+
+    chain: Runnable = (
+        RunnableLambda(fake_chain)
+        .with_listeners(on_end=on_end, on_start=on_start)
+        .map()
+    )
+
+    data = [{"name": "one"}, {"name": "two"}]
+    await chain.ainvoke(data, config={"max_concurrency": 1})
+    assert len(shared_state) == 2
+    assert value1 in shared_state.values(), "Value not found in the dictionary."
+    assert value2 in shared_state.values(), "Value not found in the dictionary."