Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 22:03:52 +00:00)
core[patch]: add response kwarg to on_llm_error
# Dependencies

None

# Twitter handle

@HKydlicek

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Parent: 1750cc464d
Commit: aa8ae31e5b
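In practice, the change lets a callback handler recover partial output when streaming fails. Below is a minimal sketch, not part of the commit (the `PartialOutputLogger` class and its name are illustrative), of a handler that reads the new `response` kwarg:

```python
from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs.llm_result import LLMResult


class PartialOutputLogger(BaseCallbackHandler):
    """Illustrative handler: logs whatever the model produced before the error."""

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # After this change, the streaming code paths pass the partial output
        # as `response`; other callers may not, so treat it as optional.
        response: Optional[LLMResult] = kwargs.get("response")
        if response is not None and response.generations:
            print("Partial text before error:", response.generations[0][0].text)
        print("LLM error:", error)
```

Passing `response` through `**kwargs` rather than adding a new positional parameter keeps existing handlers that ignore it backward compatible.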
@@ -75,7 +75,13 @@ class LLMManagerMixin:
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """
 
 
 class ChainManagerMixin:
@@ -351,7 +357,13 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        """Run when LLM errors."""
+        """Run when LLM errors.
+        Args:
+            error (BaseException): The error that occurred.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+        """
 
     async def on_chain_start(
         self,
@@ -623,6 +623,9 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
 
         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
         """
         handle_event(
             self.handlers,
@@ -689,6 +692,12 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
 
         Args:
             error (Exception or KeyboardInterrupt): The error.
+            kwargs (Any): Additional keyword arguments.
+                - response (LLMResult): The response which was generated before
+                    the error occurred.
+
+
+
         """
         await ahandle_event(
             self.handlers,
@@ -223,8 +223,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
-        try:
-            generation: Optional[ChatGenerationChunk] = None
+        generation: Optional[ChatGenerationChunk] = None
+        try:
             for chunk in self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -235,12 +235,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            run_manager.on_llm_error(e)
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
-            run_manager.on_llm_end(
-                LLMResult(generations=[[generation]]),
-            )
+            run_manager.on_llm_end(LLMResult(generations=[[generation]]))
 
     async def astream(
         self,
@@ -277,8 +280,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
-        try:
-            generation: Optional[ChatGenerationChunk] = None
+        generation: Optional[ChatGenerationChunk] = None
+        try:
             async for chunk in self._astream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -289,7 +292,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            await run_manager.on_llm_error(e)
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
             await run_manager.on_llm_end(
@@ -366,7 +374,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 )
             except BaseException as e:
                 if run_managers:
-                    run_managers[i].on_llm_error(e)
+                    run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
                 raise e
         flattened_outputs = [
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
@@ -433,7 +441,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         for i, res in enumerate(results):
             if isinstance(res, BaseException):
                 if run_managers:
-                    await run_managers[i].on_llm_error(res)
+                    await run_managers[i].on_llm_error(
+                        res, response=LLMResult(generations=[])
+                    )
                 exceptions.append(res)
         if exceptions:
             if run_managers:
@@ -384,8 +384,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
-        try:
-            generation: Optional[GenerationChunk] = None
+        generation: Optional[GenerationChunk] = None
+        try:
             for chunk in self._stream(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -396,7 +396,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            run_manager.on_llm_error(e)
+            run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
             run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -436,8 +441,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             name=config.get("run_name"),
             batch_size=1,
         )
-        try:
-            generation: Optional[GenerationChunk] = None
+        generation: Optional[GenerationChunk] = None
+        try:
             async for chunk in self._astream(
                 prompt, stop=stop, run_manager=run_manager, **kwargs
             ):
@@ -448,7 +453,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 generation += chunk
             assert generation is not None
         except BaseException as e:
-            await run_manager.on_llm_error(e)
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(
+                    generations=[[generation]] if generation else []
+                ),
+            )
             raise e
         else:
             await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
@@ -539,7 +549,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
         except BaseException as e:
             for run_manager in run_managers:
-                run_manager.on_llm_error(e)
+                run_manager.on_llm_error(e, response=LLMResult(generations=[]))
             raise e
         flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
@@ -707,7 +717,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             )
         except BaseException as e:
             await asyncio.gather(
-                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                *[
+                    run_manager.on_llm_error(e, response=LLMResult(generations=[]))
+                    for run_manager in run_managers
+                ]
             )
             raise e
         flattened_outputs = output.flatten()
@@ -14,6 +14,7 @@ class BaseFakeCallbackHandler(BaseModel):
     starts: int = 0
     ends: int = 0
     errors: int = 0
+    errors_args: List[Any] = []
     text: int = 0
     ignore_llm_: bool = False
     ignore_chain_: bool = False
@@ -52,8 +53,9 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.llm_ends += 1
         self.ends += 1
 
-    def on_llm_error_common(self) -> None:
+    def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
         self.errors += 1
+        self.errors_args.append({"args": args, "kwargs": kwargs})
 
     def on_llm_new_token_common(self) -> None:
         self.llm_streams += 1
@@ -160,7 +162,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> Any:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)
 
     def on_retry(
         self,
@@ -322,7 +324,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        self.on_llm_error_common()
+        self.on_llm_error_common(*args, **kwargs)
 
     async def on_chain_start(
         self,
@@ -45,6 +45,7 @@ class FakeListChatModel(SimpleChatModel):
     responses: List
     sleep: Optional[float] = None
     i: int = 0
+    error_on_chunk_number: Optional[int] = None
 
     @property
     def _llm_type(self) -> str:
@@ -77,9 +78,15 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
+
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
 
     async def _astream(
@@ -94,9 +101,14 @@ class FakeListChatModel(SimpleChatModel):
             self.i += 1
         else:
             self.i = 0
-        for c in response:
+        for i_c, c in enumerate(response):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield ChatGenerationChunk(message=AIMessageChunk(content=c))
 
     @property
@@ -60,6 +60,8 @@ class FakeListLLM(LLM):
 class FakeStreamingListLLM(FakeListLLM):
     """Fake streaming list LLM for testing purposes."""
 
+    error_on_chunk_number: Optional[int] = None
+
     def stream(
         self,
         input: LanguageModelInput,
@@ -69,9 +71,15 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> Iterator[str]:
         result = self.invoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 time.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c
 
     async def astream(
@@ -83,7 +91,13 @@ class FakeStreamingListLLM(FakeListLLM):
         **kwargs: Any,
     ) -> AsyncIterator[str]:
         result = await self.ainvoke(input, config)
-        for c in result:
+        for i_c, c in enumerate(result):
             if self.sleep is not None:
                 await asyncio.sleep(self.sleep)
+
+            if (
+                self.error_on_chunk_number is not None
+                and i_c == self.error_on_chunk_number
+            ):
+                raise Exception("Fake error")
             yield c
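The `error_on_chunk_number` field added to the fake models lets a test force a failure part-way through a stream. As a quick reference, it can be exercised in isolation like this (a sketch that mirrors the tests below, assuming the test-suite fakes above are importable):

```python
# Illustrative usage only, not part of the commit.
from tests.unit_tests.fake.chat_model import FakeListChatModel

llm = FakeListChatModel(responses=["test"], error_on_chunk_number=1)
try:
    for _ in llm.stream("hello"):
        pass  # the fake yields one chunk ("t"), then raises
except Exception as exc:
    print("stream failed mid-way:", exc)
```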
@@ -1,8 +1,15 @@
 """Test base chat model."""
+
 import pytest
 
 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 from tests.unit_tests.fake.chat_model import FakeListChatModel
 
 
@@ -69,3 +76,33 @@ async def test_async_batch_size(messages: list, messages_2: list) -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeListChatModel(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)
@@ -1,5 +1,13 @@
+import pytest
+
+from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
-from tests.unit_tests.fake.llm import FakeListLLM
+from tests.unit_tests.fake.callbacks import (
+    BaseFakeCallbackHandler,
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
+from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM
 
 
 def test_batch() -> None:
@@ -75,3 +83,33 @@ async def test_async_batch_size() -> None:
             pass
     assert len(cb.traced_runs) == 1
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_stream_error_callback() -> None:
+    message = "test"
+
+    def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
+        assert callback.errors == 1
+        assert len(callback.errors_args) == 1
+        llm_result: LLMResult = callback.errors_args[0]["kwargs"]["response"]
+        if i == 0:
+            assert llm_result.generations == []
+        else:
+            assert llm_result.generations[0][0].text == message[:i]
+
+    for i in range(0, 2):
+        llm = FakeStreamingListLLM(
+            responses=[message],
+            error_on_chunk_number=i,
+        )
+        with pytest.raises(Exception):
+            cb_async = FakeAsyncCallbackHandler()
+            async for _ in llm.astream("Dummy message", callbacks=[cb_async]):
+                pass
+            eval_response(cb_async, i)
+
+            cb_sync = FakeCallbackHandler()
+            for _ in llm.stream("Dummy message", callbacks=[cb_sync]):
+                pass
+
+            eval_response(cb_sync, i)