diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 4950302777f..fc7717e8547 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1878,7 +1878,7 @@ class Runnable(Generic[Input, Output], ABC): final_output_supported = False else: final_output = chunk - except StopIteration: + except (StopIteration, GeneratorExit): pass for ichunk in input_for_tracing: if final_input_supported: @@ -1892,8 +1892,6 @@ class Runnable(Generic[Input, Output], ABC): final_input_supported = False else: final_input = ichunk - except GeneratorExit: - run_manager.on_chain_end(final_output, inputs=final_input) except BaseException as e: run_manager.on_chain_error(e, inputs=final_input) raise diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 7f06928f4f2..5861b0fc1ab 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -5706,3 +5706,39 @@ async def test_listeners_async() -> None: assert len(shared_state) == 2 assert value1 in shared_state.values(), "Value not found in the dictionary." assert value2 in shared_state.values(), "Value not found in the dictionary." + + +async def test_closing_iterator_doesnt_raise_error() -> None: + """Test that closing an iterator calls on_chain_end rather than on_chain_error.""" + import time + + from langchain_core.callbacks import BaseCallbackHandler + from langchain_core.language_models.fake_chat_models import GenericFakeChatModel + from langchain_core.output_parsers import StrOutputParser + + on_chain_error_triggered = False + + class MyHandler(BaseCallbackHandler): + async def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + nonlocal on_chain_error_triggered + on_chain_error_triggered = True + + llm = GenericFakeChatModel(messages=iter(["hi there"])) + chain = llm | StrOutputParser() + chain_ = chain.with_config({"callbacks": [MyHandler()]}) + st = chain_.stream("hello") + next(st) + # This is a generator so close is defined on it. + st.close() # type: ignore + # Wait for a bit to make sure that the callback is called. + time.sleep(0.05) + assert on_chain_error_triggered is False