core[patch]: add response kwarg to on_llm_error
# Dependencies

None

# Twitter handle

@HKydlicek

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
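For context: the change lets error callbacks receive the output accumulated before a streamed generation failed, which the new tests below exercise. A minimal sketch of a custom handler reading the new keyword argument (not part of this commit; the handler class and printout are illustrative, and `response` is treated as optional since not every code path supplies it):

from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class PartialOutputHandler(BaseCallbackHandler):
    """Illustrative handler that inspects the new `response` kwarg."""

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        # After this change, streaming errors can also carry the partial LLMResult.
        response: Optional[LLMResult] = kwargs.get("response")
        if response is not None:
            print("generations produced before the failure:", response.generations)
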
@@ -14,6 +14,7 @@ class BaseFakeCallbackHandler(BaseModel):
     starts: int = 0
     ends: int = 0
     errors: int = 0
+    errors_args: List[Any] = []
     text: int = 0
     ignore_llm_: bool = False
     ignore_chain_: bool = False

@@ -52,8 +53,9 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.llm_ends += 1
         self.ends += 1

-    def on_llm_error_common(self) -> None:
+    def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
         self.errors += 1
+        self.errors_args.append({"args": args, "kwargs": kwargs})

     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1

@@ -160,7 +162,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)

     def on_retry(
         self,

@@ -322,7 +324,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)

     async def on_chain_start(
         self,

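The fake handler now records every call's arguments in `errors_args`, which is what the new tests assert against. A minimal sketch of that recording behaviour, assuming it runs inside the test suite where `tests.unit_tests.fake.callbacks` is importable (the error value and run id are made up):

from uuid import uuid4

from langchain_core.outputs import LLMResult
from tests.unit_tests.fake.callbacks import FakeCallbackHandler

handler = FakeCallbackHandler()
# Simulate the framework reporting a failure along with the partial result.
handler.on_llm_error(ValueError("boom"), run_id=uuid4(), response=LLMResult(generations=[]))

assert handler.errors == 1
# Positional and keyword arguments are both captured for later inspection.
assert isinstance(handler.errors_args[0]["kwargs"]["response"], LLMResult)
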
@@ -45,6 +45,7 @@ class FakeListChatModel(SimpleChatModel):
     responses: List
     sleep: Optional[float] = None
     i: int = 0
+    error_on_chunk_number: Optional[int] = None

     @property
     def _llm_type(self) -> str:

@@ -77,9 +78,15 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
+
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     async def _astream(

@@ -94,9 +101,14 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     @property

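With `error_on_chunk_number` set, the fake chat model now fails mid-stream, which is what lets the tests below inspect the partial `response` handed to `on_llm_error`. A minimal sketch of the injected failure on its own, under the same test-suite assumption (prompt and response text are arbitrary):

import pytest

from tests.unit_tests.fake.chat_model import FakeListChatModel

llm = FakeListChatModel(responses=["hello"], error_on_chunk_number=2)

seen = []
with pytest.raises(Exception):
    for chunk in llm.stream("hi"):
        seen.append(chunk.content)

# Chunks before the configured index are still emitted; the failure hits on index 2.
assert seen == ["h", "e"]
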
@@ -60,6 +60,8 @@ class FakeListLLM(LLM):
 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes."""

+    error_on_chunk_number: Optional[int] = None
+
     def stream(
         self,
         input: LanguageModelInput,

@@ -69,9 +71,15 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> Iterator[str]:
         result = self.invoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c

     async def astream(

@@ -83,7 +91,13 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> AsyncIterator[str]:
         result = await self.ainvoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c

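FakeStreamingListLLM gets the same switch; its overridden stream/astream raise after yielding the chunks that precede the configured index. A minimal sketch, under the same test-suite assumption:

import pytest

from tests.unit_tests.fake.llm import FakeStreamingListLLM

llm = FakeStreamingListLLM(responses=["hello"], error_on_chunk_number=1)

tokens = []
with pytest.raises(Exception):
    for token in llm.stream("hi"):
        tokens.append(token)

# Only the first character comes through before the injected failure.
assert tokens == ["h"]
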
@@ -1,8 +1,15 @@
 """Test base chat model."""
+
 import pytest

 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel

@@ -69,3 +76,33 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeListChatModel(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)

@@ -1,5 +1,13 @@
+import pytest
+
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
-from tests.unit_tests.fake.llm import FakeListLLM
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
+from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM


 def test_batch() -> None:

@@ -75,3 +83,33 @@ async def test_async_batch_size() -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeStreamingListLLM(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)