Issue 9401 - SequentialChain runs the same callbacks over and over in async mode (#9452)

Issue: https://github.com/langchain-ai/langchain/issues/9401

In the Async mode, SequentialChain implementation seems to run the same
callbacks over and over since it is re-using the same callbacks object.

Langchain version: 0.0.264, master

The implementation of this aysnc route differs from the sync route and
sync approach follows the right pattern of generating a new callbacks
object instead of re-using the old one and thus avoiding the cascading
run of callbacks at each step.

Async mode:
```
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        ...
        for i, chain in enumerate(self.chains):
            _input = await chain.arun(_input, callbacks=callbacks)
            ...
```

Regular mode:
```
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        for i, chain in enumerate(self.chains):
            _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
            ...
```

Notice how we are reusing the callbacks object in the Async code which
will have a cascading effect as we run through the chain. It runs the
same callbacks over and over resulting in issues.

Solution:
Define the async function in the same pattern as the regular one and
added tests.
---------

Co-authored-by: vamsee_yarlagadda <vamsee.y@airbnb.com>
This commit is contained in:
vamseeyarla 2023-08-18 14:26:12 -04:00 committed by GitHub
parent 99e5eaa9b1
commit 82fb56b79c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 3 deletions

View File

@ -190,11 +190,12 @@ class SimpleSequentialChain(Chain):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
_input = inputs[self.input_key] _input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains): for i, chain in enumerate(self.chains):
_input = await chain.arun(_input, callbacks=callbacks) _input = await chain.arun(
_input, callbacks=_run_manager.get_child(f"step_{i+1}")
)
if self.strip_outputs: if self.strip_outputs:
_input = _input.strip() _input = _input.strip()
await _run_manager.on_text( await _run_manager.on_text(

View File

@ -3,11 +3,15 @@ from typing import Dict, List, Optional
import pytest import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.memory.simple import SimpleMemory from langchain.memory.simple import SimpleMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeChain(Chain): class FakeChain(Chain):
@ -37,6 +41,17 @@ class FakeChain(Chain):
outputs[var] = f"{' '.join(variables)}foo" outputs[var] = f"{' '.join(variables)}foo"
return outputs return outputs
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
outputs = {}
for var in self.output_variables:
variables = [inputs[k] for k in self.input_variables]
outputs[var] = f"{' '.join(variables)}foo"
return outputs
def test_sequential_usage_single_inputs() -> None: def test_sequential_usage_single_inputs() -> None:
"""Test sequential on single input chains.""" """Test sequential on single input chains."""
@ -165,6 +180,36 @@ def test_simple_sequential_functionality() -> None:
assert output == expected_output assert output == expected_output
@pytest.mark.asyncio
@pytest.mark.parametrize("isAsync", [False, True])
async def test_simple_sequential_functionality_with_callbacks(isAsync: bool) -> None:
"""Test simple sequential functionality."""
handler_1 = FakeCallbackHandler()
handler_2 = FakeCallbackHandler()
handler_3 = FakeCallbackHandler()
chain_1 = FakeChain(
input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1]
)
chain_2 = FakeChain(
input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2]
)
chain_3 = FakeChain(
input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3]
)
chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3])
if isAsync:
output = await chain.ainvoke({"input": "123"})
else:
output = chain({"input": "123"})
expected_output = {"output": "123foofoofoo", "input": "123"}
assert output == expected_output
# Check that each of the callbacks were invoked once per the entire run
for handler in [handler_1, handler_2, handler_3]:
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_multi_input_errors() -> None: def test_multi_input_errors() -> None:
"""Test simple sequential errors if multiple input variables are expected.""" """Test simple sequential errors if multiple input variables are expected."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"]) chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])