core[patch]: add response kwarg to on_llm_error

# Dependencies
None

# Twitter handle
@HKydlicek
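
# Example

A minimal sketch of how a custom callback handler could consume the new kwarg. The `PartialOutputLogger` class and its print statements are illustrative only; the `response` kwarg passed to `on_llm_error` is the part introduced by this change.

```python
from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class PartialOutputLogger(BaseCallbackHandler):
    """Illustrative handler that logs whatever was generated before an error."""

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # Streaming code paths now pass the partial output as `response`;
        # its generations are empty if the model failed before any chunk.
        response: Optional[LLMResult] = kwargs.get("response")
        if response is not None and response.generations:
            print("Partial text before error:", response.generations[0][0].text)
        else:
            print("LLM errored before producing output:", error)
```

Handlers that ignore extra kwargs keep working unchanged, since the partial result arrives through `**kwargs`.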

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Hynek Kydlíček, 2023-12-05 00:04:48 +01:00 (committed by GitHub)
commit aa8ae31e5b, parent 1750cc464d
9 changed files with 172 additions and 25 deletions


@@ -75,7 +75,13 @@ class LLMManagerMixin:
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """

 class ChainManagerMixin:
@@ -351,7 +357,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """

     async def on_chain_start(
         self,


@@ -623,6 +623,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
         """
         handle_event(
             self.handlers,
@@ -689,6 +692,12 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
         """
         await ahandle_event(
             self.handlers,


@@ -223,8 +223,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 name=config.get("run_name"),
                 batch_size=1,
             )
-            try:
-                generation: Optional[ChatGenerationChunk] = None
+            generation: Optional[ChatGenerationChunk] = None
+            try:
                 for chunk in self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -235,12 +235,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
-                run_manager.on_llm_end(
-                    LLMResult(generations=[[generation]]),
-                )
+                run_manager.on_llm_end(LLMResult(generations=[[generation]]))

     async def astream(
         self,
@@ -277,8 +280,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 name=config.get("run_name"),
                 batch_size=1,
             )
-            try:
-                generation: Optional[ChatGenerationChunk] = None
+            generation: Optional[ChatGenerationChunk] = None
+            try:
                 async for chunk in self._astream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -289,7 +292,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                await run_manager.on_llm_error(e)
+                await run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
                 await run_manager.on_llm_end(
@@ -366,7 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 )
             except BaseException as e:
                 if run_managers:
-                    run_managers[i].on_llm_error(e)
+                    run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
                 raise e
         flattened_outputs = [
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
@@ -433,7 +441,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         for i, res in enumerate(results):
             if isinstance(res, BaseException):
                 if run_managers:
-                    await run_managers[i].on_llm_error(res)
+                    await run_managers[i].on_llm_error(
+                        res, response=LLMResult(generations=[])
+                    )
                 exceptions.append(res)
         if exceptions:
             if run_managers:


@@ -384,8 +384,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 name=config.get("run_name"),
                 batch_size=1,
             )
-            try:
-                generation: Optional[GenerationChunk] = None
+            generation: Optional[GenerationChunk] = None
+            try:
                 for chunk in self._stream(
                     prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -396,7 +396,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
                 run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -436,8 +441,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 name=config.get("run_name"),
                 batch_size=1,
             )
-            try:
-                generation: Optional[GenerationChunk] = None
+            generation: Optional[GenerationChunk] = None
+            try:
                 async for chunk in self._astream(
                     prompt, stop=stop, run_manager=run_manager, **kwargs
                 ):
@@ -448,7 +453,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                         generation += chunk
                 assert generation is not None
             except BaseException as e:
-                await run_manager.on_llm_error(e)
+                await run_manager.on_llm_error(
+                    e,
+                    response=LLMResult(
+                        generations=[[generation]] if generation else []
+                    ),
+                )
                 raise e
             else:
                 await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -539,7 +549,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 )
             except BaseException as e:
                 for run_manager in run_managers:
-                    run_manager.on_llm_error(e)
+                    run_manager.on_llm_error(e, response=LLMResult(generations=[]))
                 raise e
             flattened_outputs = output.flatten()
             for manager, flattened_output in zip(run_managers, flattened_outputs):
@@ -707,7 +717,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 )
             except BaseException as e:
                 await asyncio.gather(
-                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                    *[
+                        run_manager.on_llm_error(e, response=LLMResult(generations=[]))
+                        for run_manager in run_managers
+                    ]
                 )
                 raise e
             flattened_outputs = output.flatten()


@@ -14,6 +14,7 @@ class BaseFakeCallbackHandler(BaseModel):
     starts: int = 0
     ends: int = 0
     errors: int = 0
+    errors_args: List[Any] = []
     text: int = 0

     ignore_llm_: bool = False
     ignore_chain_: bool = False
@@ -52,8 +53,9 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.llm_ends += 1
         self.ends += 1

-    def on_llm_error_common(self) -> None:
+    def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
         self.errors += 1
+        self.errors_args.append({"args": args, "kwargs": kwargs})

     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
@@ -160,7 +162,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)

     def on_retry(
         self,
@@ -322,7 +324,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)

     async def on_chain_start(
         self,


@@ -45,6 +45,7 @@ class FakeListChatModel(SimpleChatModel):
     responses: List
     sleep: Optional[float] = None
     i: int = 0
+    error_on_chunk_number: Optional[int] = None

     @property
     def _llm_type(self) -> str:
@@ -77,9 +78,15 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     async def _astream(
@@ -94,9 +101,14 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     @property


@@ -60,6 +60,8 @@ class FakeListLLM(LLM):
 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes."""

+    error_on_chunk_number: Optional[int] = None
+
     def stream(
         self,
         input: LanguageModelInput,
@@ -69,9 +71,15 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> Iterator[str]:
         result = self.invoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c

     async def astream(
@@ -83,7 +91,13 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> AsyncIterator[str]:
         result = await self.ainvoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c


@@ -1,8 +1,15 @@
 """Test base chat model."""
 import pytest

 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel
@@ -69,3 +76,33 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeListChatModel(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
+                pass
+            eval_response(cb_sync, i)


@@ -1,5 +1,13 @@
+import pytest
+
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
-from tests.unit_tests.fake.llm import FakeListLLM
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
+from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM


 def test_batch() -> None:
@@ -75,3 +83,33 @@ async def test_async_batch_size() -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeStreamingListLLM(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dumy message", callbacks=[cb_sync]):
+                pass
+            eval_response(cb_sync, i)