From 82fb56b79cc8862607a6255d5d4271b5e1725f24 Mon Sep 17 00:00:00 2001 From: vamseeyarla Date: Fri, 18 Aug 2023 14:26:12 -0400 Subject: [PATCH] Issue 9401 - SequentialChain runs the same callbacks over and over in async mode (#9452) Issue: https://github.com/langchain-ai/langchain/issues/9401 In async mode, the SimpleSequentialChain implementation runs the same callbacks over and over because it re-uses the same callbacks object for every step. LangChain version: 0.0.264, master The implementation of the async route differs from the sync route: the sync route follows the right pattern of generating a new callbacks object at each step instead of re-using the old one, and thus avoids the cascading run of callbacks. Async mode: ``` _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() ... for i, chain in enumerate(self.chains): _input = await chain.arun(_input, callbacks=callbacks) ... ``` Regular mode: ``` _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() for i, chain in enumerate(self.chains): _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}")) ... ``` Notice how the async code re-uses a single callbacks object, which has a cascading effect as we run through the chain: the same callbacks are run over and over, resulting in issues. Solution: Define the async function following the same pattern as the regular one (a fresh child callback manager per step) and add tests. 
--------- Co-authored-by: vamsee_yarlagadda --- libs/langchain/langchain/chains/sequential.py | 5 +- .../unit_tests/chains/test_sequential.py | 47 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index 7acf33ef08d..934c3a9eee8 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ b/libs/langchain/langchain/chains/sequential.py @@ -190,11 +190,12 @@ class SimpleSequentialChain(Chain): run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() _input = inputs[self.input_key] color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): - _input = await chain.arun(_input, callbacks=callbacks) + _input = await chain.arun( + _input, callbacks=_run_manager.get_child(f"step_{i+1}") + ) if self.strip_outputs: _input = _input.strip() await _run_manager.on_text( diff --git a/libs/langchain/tests/unit_tests/chains/test_sequential.py b/libs/langchain/tests/unit_tests/chains/test_sequential.py index 12f72c6fe7d..861f013f5fb 100644 --- a/libs/langchain/tests/unit_tests/chains/test_sequential.py +++ b/libs/langchain/tests/unit_tests/chains/test_sequential.py @@ -3,11 +3,15 @@ from typing import Dict, List, Optional import pytest -from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.memory import ConversationBufferMemory from langchain.memory.simple import SimpleMemory +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler class FakeChain(Chain): @@ -37,6 +41,17 @@ 
class FakeChain(Chain): outputs[var] = f"{' '.join(variables)}foo" return outputs + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + outputs = {} + for var in self.output_variables: + variables = [inputs[k] for k in self.input_variables] + outputs[var] = f"{' '.join(variables)}foo" + return outputs + def test_sequential_usage_single_inputs() -> None: """Test sequential on single input chains.""" @@ -165,6 +180,36 @@ def test_simple_sequential_functionality() -> None: assert output == expected_output +@pytest.mark.asyncio +@pytest.mark.parametrize("isAsync", [False, True]) +async def test_simple_sequential_functionality_with_callbacks(isAsync: bool) -> None: + """Test simple sequential functionality.""" + handler_1 = FakeCallbackHandler() + handler_2 = FakeCallbackHandler() + handler_3 = FakeCallbackHandler() + chain_1 = FakeChain( + input_variables=["foo"], output_variables=["bar"], callbacks=[handler_1] + ) + chain_2 = FakeChain( + input_variables=["bar"], output_variables=["baz"], callbacks=[handler_2] + ) + chain_3 = FakeChain( + input_variables=["jack"], output_variables=["baf"], callbacks=[handler_3] + ) + chain = SimpleSequentialChain(chains=[chain_1, chain_2, chain_3]) + if isAsync: + output = await chain.ainvoke({"input": "123"}) + else: + output = chain({"input": "123"}) + expected_output = {"output": "123foofoofoo", "input": "123"} + assert output == expected_output + # Check that each of the callbacks were invoked once per the entire run + for handler in [handler_1, handler_2, handler_3]: + assert handler.starts == 1 + assert handler.ends == 1 + assert handler.errors == 0 + + def test_multi_input_errors() -> None: """Test simple sequential errors if multiple input variables are expected.""" chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar"])