mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 10:39:23 +00:00
Support using async callback handlers with sync callback manager (#10945)
The current behaviour just calls the handler without awaiting the coroutine, which results in exceptions/warnings, and obviously doesn't actually execute whatever the callback handler does <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
48a04aed75
commit
77ce9ed6f1
@ -5,12 +5,14 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
@ -370,37 +372,84 @@ def _handle_event(
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Generic event handler for CallbackManager."""
|
||||
message_strings: Optional[List[str]] = None
|
||||
for handler in handlers:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(
|
||||
handler, ignore_condition_name
|
||||
):
|
||||
getattr(handler, event_name)(*args, **kwargs)
|
||||
except NotImplementedError as e:
|
||||
if event_name == "on_chat_model_start":
|
||||
if message_strings is None:
|
||||
message_strings = [get_buffer_string(m) for m in args[1]]
|
||||
_handle_event(
|
||||
[handler],
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
args[0],
|
||||
message_strings,
|
||||
*args[2:],
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
coros: List[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
try:
|
||||
message_strings: Optional[List[str]] = None
|
||||
for handler in handlers:
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(
|
||||
handler, ignore_condition_name
|
||||
):
|
||||
event = getattr(handler, event_name)(*args, **kwargs)
|
||||
if asyncio.iscoroutine(event):
|
||||
coros.append(event)
|
||||
except NotImplementedError as e:
|
||||
if event_name == "on_chat_model_start":
|
||||
if message_strings is None:
|
||||
message_strings = [get_buffer_string(m) for m in args[1]]
|
||||
_handle_event(
|
||||
[handler],
|
||||
"on_llm_start",
|
||||
"ignore_llm",
|
||||
args[0],
|
||||
message_strings,
|
||||
*args[2:],
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
handler_name = handler.__class__.__name__
|
||||
logger.warning(
|
||||
f"NotImplementedError in {handler_name}.{event_name}"
|
||||
f" callback: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"NotImplementedError in {handler.__class__.__name__}.{event_name}"
|
||||
f" callback: {e}"
|
||||
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
|
||||
)
|
||||
if handler.raise_error:
|
||||
raise e
|
||||
if handler.raise_error:
|
||||
raise e
|
||||
finally:
|
||||
if coros:
|
||||
try:
|
||||
# Raises RuntimeError if there is no current event loop.
|
||||
asyncio.get_running_loop()
|
||||
loop_running = True
|
||||
except RuntimeError:
|
||||
loop_running = False
|
||||
|
||||
if loop_running:
|
||||
# If we try to submit this coroutine to the running loop
|
||||
# we end up in a deadlock, as we'd have gotten here from a
|
||||
# running coroutine, which we cannot interrupt to run this one.
|
||||
# The solution is to create a new loop in a new thread.
|
||||
with ThreadPoolExecutor(1) as executor:
|
||||
executor.submit(_run_coros, coros).result()
|
||||
else:
|
||||
_run_coros(coros)
|
||||
|
||||
|
||||
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
if hasattr(asyncio, "Runner"):
|
||||
# Python 3.11+
|
||||
# Run the coroutines in a new event loop, taking care to
|
||||
# - install signal handlers
|
||||
# - run pending tasks scheduled by `coros`
|
||||
# - close asyncgens and executors
|
||||
# - close the loop
|
||||
with asyncio.Runner() as runner:
|
||||
# Run the coroutine, get the result
|
||||
for coro in coros:
|
||||
runner.run(coro)
|
||||
|
||||
# Run pending tasks scheduled by coros until they are all done
|
||||
while pending := asyncio.all_tasks(runner.get_loop()):
|
||||
runner.run(asyncio.wait(pending))
|
||||
else:
|
||||
# Before Python 3.11 we need to run each coroutine in a new event loop
|
||||
# as the Runner api is not available.
|
||||
for coro in coros:
|
||||
asyncio.run(coro)
|
||||
|
||||
|
||||
async def _ahandle_event_for_handler(
|
||||
|
@ -92,6 +92,27 @@ def test_callback_manager() -> None:
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_callback_manager_with_async() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
handler3 = FakeAsyncCallbackHandler()
|
||||
handler4 = FakeAsyncCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4])
|
||||
_test_callback_manager(manager, handler1, handler2, handler3, handler4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_manager_with_async_with_running_loop() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
handler3 = FakeAsyncCallbackHandler()
|
||||
handler4 = FakeAsyncCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4])
|
||||
_test_callback_manager(manager, handler1, handler2, handler3, handler4)
|
||||
|
||||
|
||||
def test_ignore_llm() -> None:
|
||||
"""Test ignore llm param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_llm_=True)
|
||||
|
Loading…
Reference in New Issue
Block a user