mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 19:46:55 +00:00
core: Fix test_stream_error_callback (#30228)
Fixes #29436 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
b9e19c5f97
commit
8395abbb42
@ -111,22 +111,26 @@ async def test_stream_error_callback() -> None:
|
||||
else:
|
||||
assert llm_result.generations[0][0].text == message[:i]
|
||||
|
||||
for i in range(2):
|
||||
for i in range(len(message)):
|
||||
llm = FakeListChatModel(
|
||||
responses=[message],
|
||||
error_on_chunk_number=i,
|
||||
)
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
llm_astream = llm.astream("Dummy message", config={"callbacks": [cb_async]})
|
||||
for _ in range(i):
|
||||
await llm_astream.__anext__()
|
||||
with pytest.raises(FakeListChatModelError):
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||
pass
|
||||
eval_response(cb_async, i)
|
||||
await llm_astream.__anext__()
|
||||
eval_response(cb_async, i)
|
||||
|
||||
cb_sync = FakeCallbackHandler()
|
||||
for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
|
||||
pass
|
||||
|
||||
eval_response(cb_sync, i)
|
||||
cb_sync = FakeCallbackHandler()
|
||||
llm_stream = llm.stream("Dumy message", config={"callbacks": [cb_sync]})
|
||||
for _ in range(i):
|
||||
next(llm_stream)
|
||||
with pytest.raises(FakeListChatModelError):
|
||||
next(llm_stream)
|
||||
eval_response(cb_sync, i)
|
||||
|
||||
|
||||
async def test_astream_fallback_to_ainvoke() -> None:
|
||||
|
@ -7,8 +7,11 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
|
||||
from langchain_core.language_models.fake import FakeListLLMError
|
||||
from langchain_core.language_models import (
|
||||
LLM,
|
||||
BaseLLM,
|
||||
FakeListLLM,
|
||||
)
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
@ -93,34 +96,40 @@ async def test_async_batch_size() -> None:
|
||||
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
||||
|
||||
|
||||
async def test_stream_error_callback() -> None:
|
||||
message = "test"
|
||||
async def test_error_callback() -> None:
|
||||
class FailingLLMError(Exception):
|
||||
"""FailingLLMError"""
|
||||
|
||||
def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
|
||||
class FailingLLM(LLM):
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "failing-llm"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
raise FailingLLMError
|
||||
|
||||
def eval_response(callback: BaseFakeCallbackHandler) -> None:
|
||||
assert callback.errors == 1
|
||||
assert len(callback.errors_args) == 1
|
||||
llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
|
||||
if i == 0:
|
||||
assert llm_result.generations == []
|
||||
else:
|
||||
assert llm_result.generations[0][0].text == message[:i]
|
||||
assert isinstance(callback.errors_args[0]["args"][0], FailingLLMError)
|
||||
|
||||
for i in range(2):
|
||||
llm = FakeStreamingListLLM(
|
||||
responses=[message],
|
||||
error_on_chunk_number=i,
|
||||
)
|
||||
with pytest.raises(FakeListLLMError):
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
|
||||
pass
|
||||
eval_response(cb_async, i)
|
||||
llm = FailingLLM()
|
||||
cb_async = FakeAsyncCallbackHandler()
|
||||
with pytest.raises(FailingLLMError):
|
||||
await llm.ainvoke("Dummy message", config={"callbacks": [cb_async]})
|
||||
eval_response(cb_async)
|
||||
|
||||
cb_sync = FakeCallbackHandler()
|
||||
for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
|
||||
pass
|
||||
|
||||
eval_response(cb_sync, i)
|
||||
cb_sync = FakeCallbackHandler()
|
||||
with pytest.raises(FailingLLMError):
|
||||
llm.invoke("Dummy message", config={"callbacks": [cb_sync]})
|
||||
eval_response(cb_sync)
|
||||
|
||||
|
||||
async def test_astream_fallback_to_ainvoke() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user