core: Fix test_stream_error_callback (#30228)

Fixes #29436

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Christophe Bornet, 2025-03-31 16:37:22 +02:00 (committed by GitHub)
parent b9e19c5f97
commit 8395abbb42
2 changed files with 48 additions and 35 deletions


@@ -111,22 +111,26 @@ async def test_stream_error_callback() -> None:
         else:
             assert llm_result.generations[0][0].text == message[:i]

-    for i in range(2):
+    for i in range(len(message)):
         llm = FakeListChatModel(
             responses=[message],
             error_on_chunk_number=i,
         )
+        cb_async = FakeAsyncCallbackHandler()
+        llm_astream = llm.astream("Dummy message", config={"callbacks": [cb_async]})
+        for _ in range(i):
+            await llm_astream.__anext__()
         with pytest.raises(FakeListChatModelError):
-            cb_async = FakeAsyncCallbackHandler()
-            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
-                pass
-            eval_response(cb_async, i)
+            await llm_astream.__anext__()
+        eval_response(cb_async, i)

-            cb_sync = FakeCallbackHandler()
-            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
-                pass
-            eval_response(cb_sync, i)
+        cb_sync = FakeCallbackHandler()
+        llm_stream = llm.stream("Dumy message", config={"callbacks": [cb_sync]})
+        for _ in range(i):
+            next(llm_stream)
+        with pytest.raises(FakeListChatModelError):
+            next(llm_stream)
+        eval_response(cb_sync, i)


 async def test_astream_fallback_to_ainvoke() -> None:
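
Why the rework matters (a general pytest behavior, not anything langchain-specific): once the expected exception is raised inside a `with pytest.raises(...)` block, control jumps straight to the context manager, so every statement after the raising one is dead code. In the old test, both `eval_response` calls and the entire sync branch nested inside the async block were silently skipped. A minimal sketch of the pitfall, using only stock pytest:

import pytest


def test_checks_after_the_raise_never_run() -> None:
    with pytest.raises(ZeroDivisionError):
        1 / 0
        # Unreachable: the ZeroDivisionError has already propagated to
        # pytest.raises, so a failing assertion here would go unnoticed.
        raise AssertionError("never reached")
    # Checks belong here, outside the block, where they actually execute,
    # which is exactly where the new test places eval_response.

Stepping the generator by hand with `next`/`__anext__` keeps only the failing chunk inside `raises`, and widening the loop from `range(2)` to `range(len(message))` exercises an error at every chunk position.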


@@ -7,8 +7,11 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
-from langchain_core.language_models.fake import FakeListLLMError
+from langchain_core.language_models import (
+    LLM,
+    BaseLLM,
+    FakeListLLM,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
@@ -93,34 +96,40 @@ async def test_async_batch_size() -> None:
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


-async def test_stream_error_callback() -> None:
-    message = "test"
+async def test_error_callback() -> None:
+    class FailingLLMError(Exception):
+        """FailingLLMError"""

-    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+    class FailingLLM(LLM):
+        @property
+        def _llm_type(self) -> str:
+            """Return type of llm."""
+            return "failing-llm"
+
+        def _call(
+            self,
+            prompt: str,
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            raise FailingLLMError
+
+    def eval_response(callback: BaseFakeCallbackHandler) -> None:
         assert callback.errors == 1
         assert len(callback.errors_args) == 1
-        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
-        if i == 0:
-            assert llm_result.generations == []
-        else:
-            assert llm_result.generations[0][0].text == message[:i]
+        assert isinstance(callback.errors_args[0]["args"][0], FailingLLMError)

-    for i in range(2):
-        llm = FakeStreamingListLLM(
-            responses=[message],
-            error_on_chunk_number=i,
-        )
-        with pytest.raises(FakeListLLMError):
-            cb_async = FakeAsyncCallbackHandler()
-            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
-                pass
-            eval_response(cb_async, i)
+    llm = FailingLLM()
+    cb_async = FakeAsyncCallbackHandler()
+    with pytest.raises(FailingLLMError):
+        await llm.ainvoke("Dummy message", config={"callbacks": [cb_async]})
+    eval_response(cb_async)

-            cb_sync = FakeCallbackHandler()
-            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
-                pass
-            eval_response(cb_sync, i)
+    cb_sync = FakeCallbackHandler()
+    with pytest.raises(FailingLLMError):
+        llm.invoke("Dummy message", config={"callbacks": [cb_sync]})
+    eval_response(cb_sync)


 async def test_astream_fallback_to_ainvoke() -> None:
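
For context on what the new assertions inspect: when a model raises, the callback manager invokes each handler's `on_llm_error` with the exception as its first positional argument, which is what `callback.errors_args[0]["args"][0]` unpacks. A minimal sketch of a recording handler, assuming only the `BaseCallbackHandler` interface from `langchain_core.callbacks` (the class and attribute names below are illustrative, not from the PR):

from typing import Any

from langchain_core.callbacks import BaseCallbackHandler


class ErrorRecordingHandler(BaseCallbackHandler):
    """Illustrative stand-in for the bookkeeping FakeCallbackHandler does."""

    def __init__(self) -> None:
        self.seen: list[BaseException] = []

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        # The FailingLLMError raised in _call arrives here as `error`;
        # run metadata (run_id and friends) comes through **kwargs.
        self.seen.append(error)

The sync and async paths surface the error to handlers in the same shape, which is why one `eval_response` helper serves both halves of the test.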