mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 03:01:29 +00:00
Issue 9401 - SequentialChain runs the same callbacks over and over in async mode (#9452)
Issue: https://github.com/langchain-ai/langchain/issues/9401 In the Async mode, SequentialChain implementation seems to run the same callbacks over and over since it is re-using the same callbacks object. Langchain version: 0.0.264, master The implementation of this aysnc route differs from the sync route and sync approach follows the right pattern of generating a new callbacks object instead of re-using the old one and thus avoiding the cascading run of callbacks at each step. Async mode: ``` _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() ... for i, chain in enumerate(self.chains): _input = await chain.arun(_input, callbacks=callbacks) ... ``` Regular mode: ``` _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() for i, chain in enumerate(self.chains): _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}")) ... ``` Notice how we are reusing the callbacks object in the Async code which will have a cascading effect as we run through the chain. It runs the same callbacks over and over resulting in issues. Solution: Define the async function in the same pattern as the regular one and added tests. --------- Co-authored-by: vamsee_yarlagadda <vamsee.y@airbnb.com>
This commit is contained in:
parent
99e5eaa9b1
commit
82fb56b79c
@ -190,11 +190,12 @@ class SimpleSequentialChain(Chain):
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
_input = inputs[self.input_key]
|
||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||
for i, chain in enumerate(self.chains):
|
||||
_input = await chain.arun(_input, callbacks=callbacks)
|
||||
_input = await chain.arun(
|
||||
_input, callbacks=_run_manager.get_child(f"step_{i+1}")
|
||||
)
|
||||
if self.strip_outputs:
|
||||
_input = _input.strip()
|
||||
await _run_manager.on_text(
|
||||
|
@ -3,11 +3,15 @@ from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeChain(Chain):
|
||||
@ -37,6 +41,17 @@ class FakeChain(Chain):
|
||||
outputs[var] = f"{' '.join(variables)}foo"
|
||||
return outputs
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
outputs = {}
|
||||
for var in self.output_variables:
|
||||
variables = [inputs[k] for k in self.input_variables]
|
||||
outputs[var] = f"{' '.join(variables)}foo"
|
||||
return outputs
|
||||
|
||||
|
||||
def test_sequential_usage_single_inputs() -> None:
|
||||
"""Test sequential on single input chains."""
|
||||
@ -165,6 +180,36 @@ def test_simple_sequential_functionality() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("isAsync", [False, True])
|
||||
async def test_simple_sequential_functionality_with_callbacks(isAsync: bool) -> None:
|
||||
"""Test simple sequential functionality."""
|
||||
handler_1 = FakeCallbackHandler()
|
||||
handler_2 = FakeCallbackHandler()
|
||||
handler_3 = FakeCallbackHandler()
|
||||
chain_1 = FakeChain(
|
||||
input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1]
|
||||
)
|
||||
chain_2 = FakeChain(
|
||||
input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2]
|
||||
)
|
||||
chain_3 = FakeChain(
|
||||
input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3]
|
||||
)
|
||||
chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3])
|
||||
if isAsync:
|
||||
output = await chain.ainvoke({"input": "123"})
|
||||
else:
|
||||
output = chain({"input": "123"})
|
||||
expected_output = {"output": "123foofoofoo", "input": "123"}
|
||||
assert output == expected_output
|
||||
# Check that each of the callbacks were invoked once per the entire run
|
||||
for handler in [handler_1, handler_2, handler_3]:
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_multi_input_errors() -> None:
|
||||
"""Test simple sequential errors if multiple input variables are expected."""
|
||||
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])
|
||||
|
Loading…
Reference in New Issue
Block a user