Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-27 08:58:48 +00:00)
core: Fix test_stream_error_callback (#30228)

Fixes #29436

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

Parent: b9e19c5f97
Commit: 8395abbb42
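Both of the old tests below put their follow-up assertions inside the with pytest.raises(...) block, after the statement that raises, so those assertions were silently skipped (apparently the problem behind #29436). A minimal, self-contained sketch of the pitfall; the names are illustrative and not taken from the PR:

    import pytest


    def failing_stream():
        """Yield one chunk, then blow up, mimicking a stream that errors."""
        yield "t"
        raise ValueError("boom")


    def test_assertions_inside_raises_never_run() -> None:
        seen = []
        with pytest.raises(ValueError):
            for chunk in failing_stream():
                seen.append(chunk)
            seen.append("checked")  # unreachable: the loop above raised first
        # Only work done before the exception is visible here, so the real
        # checks must live outside the pytest.raises block.
        assert seen == ["t"]

Both rewritten tests therefore keep only the failing call inside pytest.raises and run their callback assertions afterwards.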
Chat model streaming test, test_stream_error_callback:

@@ -111,21 +111,25 @@ async def test_stream_error_callback() -> None:
         else:
             assert llm_result.generations[0][0].text == message[:i]

-    for i in range(2):
+    for i in range(len(message)):
         llm = FakeListChatModel(
             responses=[message],
             error_on_chunk_number=i,
         )
-        with pytest.raises(FakeListChatModelError):
-            cb_async = FakeAsyncCallbackHandler()
-            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
-                pass
-            eval_response(cb_async, i)
-
-            cb_sync = FakeCallbackHandler()
-            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
-                pass
-            eval_response(cb_sync, i)
+        cb_async = FakeAsyncCallbackHandler()
+        llm_astream = llm.astream("Dummy message", config={"callbacks": [cb_async]})
+        for _ in range(i):
+            await llm_astream.__anext__()
+        with pytest.raises(FakeListChatModelError):
+            await llm_astream.__anext__()
+        eval_response(cb_async, i)
+
+        cb_sync = FakeCallbackHandler()
+        llm_stream = llm.stream("Dumy message", config={"callbacks": [cb_sync]})
+        for _ in range(i):
+            next(llm_stream)
+        with pytest.raises(FakeListChatModelError):
+            next(llm_stream)
+        eval_response(cb_sync, i)
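The new chat model test advances the stream by hand up to the chunk that is expected to fail and wraps only that final pull in pytest.raises. A rough sketch of the same control flow using a plain Python generator instead of FakeListChatModel, purely to illustrate the pattern (all names here are made up for the example):

    import pytest


    def stream_with_error(message: str, error_on_chunk_number: int):
        """Yield the message one character at a time, failing on the given chunk."""
        for index, char in enumerate(message):
            if index == error_on_chunk_number:
                raise ValueError("boom")
            yield char


    def test_error_surfaces_on_expected_chunk() -> None:
        message = "test"
        for i in range(len(message)):
            stream = stream_with_error(message, error_on_chunk_number=i)
            for _ in range(i):
                next(stream)  # chunks before the failure are expected to succeed
            with pytest.raises(ValueError):
                next(stream)  # the i-th pull is the one that must raise
            # assertions about side effects go here, outside pytest.raises

Keeping the raises block down to a single call means everything after it, such as the eval_response checks in the real test, always executes.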
LLM tests, imports:

@@ -7,8 +7,11 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import BaseLLM, FakeListLLM, FakeStreamingListLLM
-from langchain_core.language_models.fake import FakeListLLMError
+from langchain_core.language_models import (
+    LLM,
+    BaseLLM,
+    FakeListLLM,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
LLM tests, test_stream_error_callback is replaced by a new test_error_callback built around an LLM that always raises, instead of FakeStreamingListLLM:

@@ -93,34 +96,40 @@ async def test_async_batch_size() -> None:
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


-async def test_stream_error_callback() -> None:
-    message = "test"
-
-    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
-        assert callback.errors == 1
-        assert len(callback.errors_args) == 1
-        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
-        if i == 0:
-            assert llm_result.generations == []
-        else:
-            assert llm_result.generations[0][0].text == message[:i]
-
-    for i in range(2):
-        llm = FakeStreamingListLLM(
-            responses=[message],
-            error_on_chunk_number=i,
-        )
-        with pytest.raises(FakeListLLMError):
-            cb_async = FakeAsyncCallbackHandler()
-            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
-                pass
-            eval_response(cb_async, i)
-
-            cb_sync = FakeCallbackHandler()
-            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
-                pass
-            eval_response(cb_sync, i)
+async def test_error_callback() -> None:
+    class FailingLLMError(Exception):
+        """FailingLLMError"""
+
+    class FailingLLM(LLM):
+        @property
+        def _llm_type(self) -> str:
+            """Return type of llm."""
+            return "failing-llm"
+
+        def _call(
+            self,
+            prompt: str,
+            stop: Optional[list[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            raise FailingLLMError
+
+    def eval_response(callback: BaseFakeCallbackHandler) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        assert isinstance(callback.errors_args[0]["args"][0], FailingLLMError)
+
+    llm = FailingLLM()
+    cb_async = FakeAsyncCallbackHandler()
+    with pytest.raises(FailingLLMError):
+        await llm.ainvoke("Dummy message", config={"callbacks": [cb_async]})
+    eval_response(cb_async)
+
+    cb_sync = FakeCallbackHandler()
+    with pytest.raises(FailingLLMError):
+        llm.invoke("Dummy message", config={"callbacks": [cb_sync]})
+    eval_response(cb_sync)


 async def test_astream_fallback_to_ainvoke() -> None:
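Both tests assert on callback.errors and callback.errors_args, attributes of the fake handlers imported from tests.unit_tests.fake.callbacks, which this diff does not touch. As a rough idea of what such an error-recording handler can look like, here is a simplified sketch with assumed names; it is not the project's actual FakeCallbackHandler:

    from typing import Any

    from langchain_core.callbacks import BaseCallbackHandler


    class ErrorRecordingHandler(BaseCallbackHandler):
        """Simplified stand-in for an error-counting test callback."""

        def __init__(self) -> None:
            self.errors = 0
            self.errors_args: list[dict[str, Any]] = []

        def on_llm_error(self, error: BaseException, *args: Any, **kwargs: Any) -> None:
            # Record every on_llm_error invocation so a test can assert that
            # exactly one error was reported and inspect what came with it.
            self.errors += 1
            self.errors_args.append({"args": (error, *args), "kwargs": kwargs})

Passing such a handler via config={"callbacks": [...]} is what the rewritten tests rely on: the model reports the exception to on_llm_error, and the exception still propagates to the caller, where pytest.raises catches it.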