diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py
index 14078755f4f..ed30e50ff14 100644
--- a/libs/core/langchain_core/callbacks/base.py
+++ b/libs/core/langchain_core/callbacks/base.py
@@ -75,7 +75,13 @@ class LLMManagerMixin:
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """


 class ChainManagerMixin:
@@ -351,7 +357,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """

     async def on_chain_start(
         self,
diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py
index 402900c6321..b1bb0119279 100644
--- a/libs/core/langchain_core/callbacks/manager.py
+++ b/libs/core/langchain_core/callbacks/manager.py
@@ -623,6 +623,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):

         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
         """
         handle_event(
             self.handlers,
@@ -689,6 +692,12 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):

         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+
+
+
         """
         await ahandle_event(
             self.handlers,
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 69bede4a12a..24bb4114cb4 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -223,8 +223,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
+        generation: Optional[ChatGenerationChunk] = None
         try:
-            generation: Optional[ChatGenerationChunk] = None
             for chunk in self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -235,12 +235,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            run_manager.on_llm_error(e)
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
-            run_manager.on_llm_end(
-                LLMResult(generations=[[generation]]),
-            )
+            run_manager.on_llm_end(LLMResult(generations=[[generation]]))

     async def astream(
         self,
@@ -277,8 +280,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
+        generation: Optional[ChatGenerationChunk] = None
         try:
-            generation: Optional[ChatGenerationChunk] = None
             async for chunk in self._astream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -289,7 +292,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            await run_manager.on_llm_error(e)
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
             await run_manager.on_llm_end(
@@ -366,7 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 )
             except BaseException as e:
                 if run_managers:
-                    run_managers[i].on_llm_error(e)
+                    run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
                 raise e
         flattened_outputs = [
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
@@ -433,7 +441,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         for i, res in enumerate(results):
             if isinstance(res, BaseException):
                 if run_managers:
-                    await run_managers[i].on_llm_error(res)
+                    await run_managers[i].on_llm_error(
+                        res, response=LLMResult(generations=[])
+                    )
                 exceptions.append(res)
         if exceptions:
             if run_managers:
diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index dea4375b7f4..e0e830d10be 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -384,8 +384,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
+        generation: Optional[GenerationChunk] = None
         try:
-            generation: Optional[GenerationChunk] = None
             for chunk in self._stream(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -396,7 +396,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            run_manager.on_llm_error(e)
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
             run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -436,8 +441,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
+        generation: Optional[GenerationChunk] = None
         try:
-            generation: Optional[GenerationChunk] = None
             async for chunk in self._astream(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -448,7 +453,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            await run_manager.on_llm_error(e)
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
             await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -539,7 +549,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
         except BaseException as e:
             for run_manager in run_managers:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(e, response=LLMResult(generations=[]))
             raise e
         flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
@@ -707,7 +717,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
         except BaseException as e:
             await asyncio.gather(
-                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                *[
+                    run_manager.on_llm_error(e, response=LLMResult(generations=[]))
+                    for run_manager in run_managers
+                ]
             )
             raise e
         flattened_outputs = output.flatten()
diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py
index 2a2af92269f..b2bef343fff 100644
--- a/libs/core/tests/unit_tests/fake/callbacks.py
+++ b/libs/core/tests/unit_tests/fake/callbacks.py
@@ -14,6 +14,7 @@ class BaseFakeCallbackHandler(BaseModel):
     starts: int = 0
     ends: int = 0
     errors: int = 0
+    errors_args: List[Any] = []
     text: int = 0
     ignore_llm_: bool = False
     ignore_chain_: bool = False
@@ -52,8 +53,9 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.llm_ends += 1
         self.ends += 1

-    def on_llm_error_common(self) -> None:
+    def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
         self.errors += 1
+        self.errors_args.append({"args": args, "kwargs": kwargs})

     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
@@ -160,7 +162,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)

     def on_retry(
         self,
@@ -322,7 +324,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)

     async def on_chain_start(
         self,
diff --git a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py
index e1268ad4fd3..717ab02533f 100644
--- a/libs/core/tests/unit_tests/fake/chat_model.py
+++ b/libs/core/tests/unit_tests/fake/chat_model.py
@@ -45,6 +45,7 @@ class FakeListChatModel(SimpleChatModel):
     responses: List
     sleep: Optional[float] = None
     i: int = 0
+    error_on_chunk_number: Optional[int] = None

     @property
     def _llm_type(self) -> str:
@@ -77,9 +78,15 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
+
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     async def _astream(
@@ -94,9 +101,14 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))

     @property
diff --git a/libs/core/tests/unit_tests/fake/llm.py b/libs/core/tests/unit_tests/fake/llm.py
index 1ebff8d8ca1..165e5b3d2df 100644
--- a/libs/core/tests/unit_tests/fake/llm.py
+++ b/libs/core/tests/unit_tests/fake/llm.py
@@ -60,6 +60,8 @@ class FakeListLLM(LLM):
 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes."""

+    error_on_chunk_number: Optional[int] = None
+
     def stream(
         self,
         input: LanguageModelInput,
@@ -69,9 +71,15 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> Iterator[str]:
         result = self.invoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c

     async def astream(
@@ -83,7 +91,13 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> AsyncIterator[str]:
         result = await self.ainvoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 0f406a06aef..24c49f79a3f 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -1,8 +1,15 @@
 """Test base chat model."""
+
 import pytest

 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel


@@ -69,3 +76,33 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
             pass
         assert len(cb.traced_runs) == 1
         assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeListChatModel(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)
diff --git a/libs/core/tests/unit_tests/language_models/llms/test_base.py b/libs/core/tests/unit_tests/language_models/llms/test_base.py
index 37b81a0ed22..a6e866cf976 100644
--- a/libs/core/tests/unit_tests/language_models/llms/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -1,5 +1,13 @@
+import pytest
+
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
-from tests.unit_tests.fake.llm import FakeListLLM
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
+from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM


 def test_batch() -> None:
@@ -75,3 +83,33 @@ async def test_async_batch_size() -> None:
             pass
         assert len(cb.traced_runs) == 1
         assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeStreamingListLLM(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)
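Not part of the diff: a minimal sketch of how a downstream callback handler could consume the partial output that `on_llm_error` now receives. The class name `PartialOutputLogger` and its `partial_texts` attribute are illustrative inventions, not from this PR; the method signature and the `response` keyword follow the handler interface and docstrings shown above.

```python
from typing import Any, List, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class PartialOutputLogger(BaseCallbackHandler):
    """Illustrative handler: records whatever text streamed before a failure."""

    def __init__(self) -> None:
        self.partial_texts: List[str] = []

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # Per this PR, streaming paths forward the chunks accumulated before
        # the error as an LLMResult under the "response" keyword; the batch
        # (non-streaming) paths pass an empty LLMResult instead.
        response: Optional[LLMResult] = kwargs.get("response")
        if response is not None and response.generations:
            self.partial_texts.append(response.generations[0][0].text)
```

Such a handler would be attached the same way as the fake handlers in the tests above, e.g. `llm.stream("...", callbacks=[PartialOutputLogger()])`, assuming the model raises mid-stream.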