mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 21:20:33 +00:00
Add AsyncIteratorCallbackHandler (#2329)
This commit is contained in:
parent
6e4e7d2637
commit
6f39e88a2c
@ -5,6 +5,7 @@ from typing import Generator, Optional
|
|||||||
|
|
||||||
from langchain.callbacks.aim_callback import AimCallbackHandler
|
from langchain.callbacks.aim_callback import AimCallbackHandler
|
||||||
from langchain.callbacks.base import (
|
from langchain.callbacks.base import (
|
||||||
|
AsyncCallbackManager,
|
||||||
BaseCallbackHandler,
|
BaseCallbackHandler,
|
||||||
BaseCallbackManager,
|
BaseCallbackManager,
|
||||||
CallbackManager,
|
CallbackManager,
|
||||||
@ -13,6 +14,7 @@ from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
|
|||||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||||
from langchain.callbacks.shared import SharedCallbackManager
|
from langchain.callbacks.shared import SharedCallbackManager
|
||||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||||
from langchain.callbacks.tracers import SharedLangChainTracer
|
from langchain.callbacks.tracers import SharedLangChainTracer
|
||||||
from langchain.callbacks.wandb_callback import WandbCallbackHandler
|
from langchain.callbacks.wandb_callback import WandbCallbackHandler
|
||||||
|
|
||||||
@ -69,12 +71,14 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CallbackManager",
|
"CallbackManager",
|
||||||
|
"AsyncCallbackManager",
|
||||||
"OpenAICallbackHandler",
|
"OpenAICallbackHandler",
|
||||||
"SharedCallbackManager",
|
"SharedCallbackManager",
|
||||||
"StdOutCallbackHandler",
|
"StdOutCallbackHandler",
|
||||||
"AimCallbackHandler",
|
"AimCallbackHandler",
|
||||||
"WandbCallbackHandler",
|
"WandbCallbackHandler",
|
||||||
"ClearMLCallbackHandler",
|
"ClearMLCallbackHandler",
|
||||||
|
"AsyncIteratorCallbackHandler",
|
||||||
"get_openai_callback",
|
"get_openai_callback",
|
||||||
"set_tracing_callback_manager",
|
"set_tracing_callback_manager",
|
||||||
"set_default_callback_manager",
|
"set_default_callback_manager",
|
||||||
|
66
langchain/callbacks/streaming_aiter.py
Normal file
66
langchain/callbacks/streaming_aiter.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
|
||||||
|
|
||||||
|
from langchain.callbacks.base import AsyncCallbackHandler
|
||||||
|
from langchain.schema import LLMResult
|
||||||
|
|
||||||
|
# TODO If used by two LLM runs in parallel this won't work as expected
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||||||
|
"""Callback handler that returns an async iterator."""
|
||||||
|
|
||||||
|
queue: asyncio.Queue[str]
|
||||||
|
|
||||||
|
done: asyncio.Event
|
||||||
|
|
||||||
|
@property
|
||||||
|
def always_verbose(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.queue = asyncio.Queue()
|
||||||
|
self.done = asyncio.Event()
|
||||||
|
|
||||||
|
async def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
# If two calls are made in a row, this resets the state
|
||||||
|
self.done.clear()
|
||||||
|
|
||||||
|
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
|
self.queue.put_nowait(token)
|
||||||
|
|
||||||
|
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
|
self.done.set()
|
||||||
|
|
||||||
|
async def on_llm_error(
|
||||||
|
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
self.done.set()
|
||||||
|
|
||||||
|
# TODO implement the other methods
|
||||||
|
|
||||||
|
async def aiter(self) -> AsyncIterator[str]:
|
||||||
|
while not self.queue.empty() or not self.done.is_set():
|
||||||
|
# Wait for the next token in the queue,
|
||||||
|
# but stop waiting if the done event is set
|
||||||
|
done, _ = await asyncio.wait(
|
||||||
|
[
|
||||||
|
asyncio.ensure_future(self.queue.get()),
|
||||||
|
asyncio.ensure_future(self.done.wait()),
|
||||||
|
],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the value of the first completed task
|
||||||
|
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
|
||||||
|
|
||||||
|
# If the extracted value is the boolean True, the done event was set
|
||||||
|
if token_or_done is True:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Otherwise, the extracted value is a token, which we yield
|
||||||
|
yield token_or_done
|
Loading…
Reference in New Issue
Block a user