From 8395abbb42abb7c35473028720cd7e41fb8f4710 Mon Sep 17 00:00:00 2001
From: Christophe Bornet
Date: Mon, 31 Mar 2025 16:37:22 +0200
Subject: [PATCH] core: Fix test_stream_error_callback (#30228)

Fixes #29436

---------

Co-authored-by: Eugene Yurtsev
---
 .../language_models/chat_models/test_base.py | 24 ++++----
 .../language_models/llms/test_base.py        | 59 +++++++++++--------
 2 files changed, 48 insertions(+), 35 deletions(-)

diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 284ffb23d18..01f0245f303 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -111,22 +111,26 @@ async def test_stream_error_callback() -> None:
         else:
             assert llm_result.generations[0][0].text == message[:i]
 
-    for i in range(2):
+    for i in range(len(message)):
         llm = FakeListChatModel(
             responses=[message],
             error_on_chunk_number=i,
         )
+        cb_async = FakeAsyncCallbackHandler()
+        llm_astream = llm.astream("Dummy message", config={"callbacks": [cb_async]})
+        for _ in range(i):
+            await llm_astream.__anext__()
         with pytest.raises(FakeListChatModelError):
-            cb_async = FakeAsyncCallbackHandler()
-            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
-                pass
-            eval_response(cb_async, i)
+            await llm_astream.__anext__()
+        eval_response(cb_async, i)
 
-            cb_sync = FakeCallbackHandler()
-            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
-                pass
-
-            eval_response(cb_sync, i)
+        cb_sync = FakeCallbackHandler()
+        llm_stream = llm.stream("Dumy message", config={"callbacks": [cb_sync]})
+        for _ in range(i):
+            next(llm_stream)
+        with pytest.raises(FakeListChatModelError):
+            next(llm_stream)
+        eval_response(cb_sync, i)
 
 
 async def test_astream_fallback_to_ainvoke() -> None:
diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py
index 0f28762bb21..a5741b2225f 100644
--- a/libs/core/tests/unit_tests/language_models/llms/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -7,8 +7,11 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
-from langchain_core.language_models.fake import FakeListLLMError
+from langchain_core.language_models import (
+    LLM,
+    BaseLLM,
+    FakeListLLM,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
@@ -93,34 +96,40 @@ async def test_async_batch_size() -> None:
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
 
 
-async def test_stream_error_callback() -> None:
-    message = "test"
+async def test_error_callback() -> None:
+    class FailingLLMError(Exception):
+        """FailingLLMError"""
 
-    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+    class FailingLLM(LLM):
+        @property
+        def _llm_type(self) -> str:
+            """Return type of llm."""
+            return "failing-llm"
+
+        def _call(
+            self,
+            prompt: str,
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            raise FailingLLMError
+
+    def eval_response(callback: BaseFakeCallbackHandler) -> None:
         assert callback.errors == 1
         assert len(callback.errors_args) == 1
-        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
-        if i == 0:
-            assert llm_result.generations == []
-        else:
-            assert llm_result.generations[0][0].text == message[:i]
+        assert isinstance(callback.errors_args[0]["args"][0], FailingLLMError)
 
-    for i in range(2):
-        llm = FakeStreamingListLLM(
-            responses=[message],
-            error_on_chunk_number=i,
-        )
-        with pytest.raises(FakeListLLMError):
-            cb_async = FakeAsyncCallbackHandler()
-            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
-                pass
-            eval_response(cb_async, i)
+    llm = FailingLLM()
+    cb_async = FakeAsyncCallbackHandler()
+    with pytest.raises(FailingLLMError):
+        await llm.ainvoke("Dummy message", config={"callbacks": [cb_async]})
+    eval_response(cb_async)
 
-        cb_sync = FakeCallbackHandler()
-        for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
-            pass
-
-        eval_response(cb_sync, i)
+    cb_sync = FakeCallbackHandler()
+    with pytest.raises(FailingLLMError):
+        llm.invoke("Dummy message", config={"callbacks": [cb_sync]})
+    eval_response(cb_sync)
 
 
 async def test_astream_fallback_to_ainvoke() -> None: