Issue 9401 - SequentialChain runs the same callbacks over and over in async mode (#9452)

Issue: https://github.com/langchain-ai/langchain/issues/9401

In the Async mode, SequentialChain implementation seems to run the same
callbacks over and over since it is re-using the same callbacks object.

Langchain version: 0.0.264, master

The implementation of this aysnc route differs from the sync route and
sync approach follows the right pattern of generating a new callbacks
object instead of re-using the old one and thus avoiding the cascading
run of callbacks at each step.

Async mode:
```
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        ...
        for i, chain in enumerate(self.chains):
            _input = await chain.arun(_input, callbacks=callbacks)
            ...
```

Regular mode:
```
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        for i, chain in enumerate(self.chains):
            _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
            ...
```

Notice how we are reusing the callbacks object in the Async code which
will have a cascading effect as we run through the chain. It runs the
same callbacks over and over resulting in issues.

Solution:
Define the async function in the same pattern as the regular one and
added tests.
---------

Co-authored-by: vamsee_yarlagadda <vamsee.y@airbnb.com>
This commit is contained in:
vamseeyarla 2023-08-18 14:26:12 -04:00 committed by GitHub
parent 99e5eaa9b1
commit 82fb56b79c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 3 deletions

View File

@ -190,11 +190,12 @@ class SimpleSequentialChain(Chain):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = await chain.arun(_input, callbacks=callbacks)
_input = await chain.arun(
_input, callbacks=_run_manager.get_child(f"step_{i+1}")
)
if self.strip_outputs:
_input = _input.strip()
await _run_manager.on_text(

View File

@ -3,11 +3,15 @@ from typing import Dict, List, Optional
import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory import ConversationBufferMemory
from langchain.memory.simple import SimpleMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeChain(Chain):
@ -37,6 +41,17 @@ class FakeChain(Chain):
outputs[var] = f"{' '.join(variables)}foo"
return outputs
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
outputs = {}
for var in self.output_variables:
variables = [inputs[k] for k in self.input_variables]
outputs[var] = f"{' '.join(variables)}foo"
return outputs
def test_sequential_usage_single_inputs() -> None:
"""Test sequential on single input chains."""
@ -165,6 +180,36 @@ def test_simple_sequential_functionality() -> None:
assert output == expected_output
@pytest.mark.asyncio
@pytest.mark.parametrize("isAsync", [False, True])
async def test_simple_sequential_functionality_with_callbacks(isAsync: bool) -> None:
"""Test simple sequential functionality."""
handler_1 = FakeCallbackHandler()
handler_2 = FakeCallbackHandler()
handler_3 = FakeCallbackHandler()
chain_1 = FakeChain(
input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1]
)
chain_2 = FakeChain(
input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2]
)
chain_3 = FakeChain(
input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3]
)
chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3])
if isAsync:
output = await chain.ainvoke({"input": "123"})
else:
output = chain({"input": "123"})
expected_output = {"output": "123foofoofoo", "input": "123"}
assert output == expected_output
# Check that each of the callbacks were invoked once per the entire run
for handler in [handler_1, handler_2, handler_3]:
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_multi_input_errors() -> None:
"""Test simple sequential errors if multiple input variables are expected."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])