Enable streaming for OpenAI LLM (#986)

* Support an `on_llm_new_token` callback that users can implement when
`OpenAI.streaming` is set to `True`.
This commit is contained in:
Ankush Gola
2023-02-14 15:06:14 -08:00
committed by GitHub
parent f05f025e41
commit caa8e4742e
26 changed files with 1311 additions and 155 deletions

View File

@@ -3,12 +3,12 @@ from typing import Any, Dict, List, Union
from pydantic import BaseModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Fake callback handler for testing."""
class BaseFakeCallbackHandler(BaseModel):
"""Base fake callback handler for testing."""
starts: int = 0
ends: int = 0
@@ -44,10 +44,15 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
llm_streams: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_ends: int = 0
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
"""Fake callback handler for testing."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@@ -55,6 +60,10 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
self.llm_starts += 1
self.starts += 1
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.llm_streams += 1
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.llm_ends += 1
@@ -110,3 +119,74 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
"""Run when agent ends running."""
self.agent_ends += 1
self.ends += 1
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
    """Async fake callback handler that tallies callback invocations for tests.

    Each hook simply bumps the relevant counters inherited from
    ``BaseFakeCallbackHandler`` so tests can assert how many times the
    callback manager dispatched each event.
    """

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Count an LLM start event."""
        self.starts += 1
        self.llm_starts += 1

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Count a streamed LLM token."""
        self.llm_streams += 1

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Count an LLM end event."""
        self.ends += 1
        self.llm_ends += 1

    async def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Count an LLM error."""
        self.errors += 1

    async def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Count a chain start event."""
        self.starts += 1
        self.chain_starts += 1

    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Count a chain end event."""
        self.ends += 1
        self.chain_ends += 1

    async def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Count a chain error."""
        self.errors += 1

    async def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        """Count a tool start event."""
        self.starts += 1
        self.tool_starts += 1

    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Count a tool end event."""
        self.ends += 1
        self.tool_ends += 1

    async def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Count a tool error."""
        self.errors += 1

    async def on_text(self, text: str, **kwargs: Any) -> None:
        """Count a text event."""
        self.text += 1

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Count an agent finish event."""
        self.ends += 1
        self.agent_ends += 1

View File

@@ -1,13 +1,24 @@
"""Test CallbackManager."""
from typing import Tuple
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
import pytest
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import AgentAction, AgentFinish, LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.callbacks.fake_callback_handler import (
BaseFakeCallbackHandler,
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
def _test_callback_manager(
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
manager.on_llm_start({}, [])
@@ -20,6 +31,27 @@ def _test_callback_manager(
manager.on_tool_end("")
manager.on_tool_error(Exception())
manager.on_agent_finish(AgentFinish(log="", return_values={}))
_check_num_calls(handlers)
async def _test_callback_manager_async(
    manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
    """Fire every callback on *manager* once, then verify handler tallies."""
    # Build all event coroutines up front, then await them strictly in order
    # so dispatch order matches a plain sequential-await implementation.
    events = (
        manager.on_llm_start({}, []),
        manager.on_llm_end(LLMResult(generations=[])),
        manager.on_llm_error(Exception()),
        manager.on_chain_start({"name": "foo"}, {}),
        manager.on_chain_end({}),
        manager.on_chain_error(Exception()),
        manager.on_tool_start({}, AgentAction("", "", "")),
        manager.on_tool_end(""),
        manager.on_tool_error(Exception()),
        manager.on_agent_finish(AgentFinish(log="", return_values={})),
    )
    for event in events:
        await event
    _check_num_calls(handlers)
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
for handler in handlers:
if handler.always_verbose:
assert handler.starts == 3
@@ -128,3 +160,21 @@ def test_shared_callback_manager() -> None:
manager1.add_handler(handler1)
manager2.add_handler(handler2)
_test_callback_manager(manager1, handler1, handler2)
@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
    """AsyncCallbackManager should fan each event out to every async handler."""
    verbose_handler = FakeAsyncCallbackHandler(always_verbose_=True)
    quiet_handler = FakeAsyncCallbackHandler()
    await _test_callback_manager_async(
        AsyncCallbackManager([verbose_handler, quiet_handler]),
        verbose_handler,
        quiet_handler,
    )
@pytest.mark.asyncio
async def test_async_callback_manager_sync_handler() -> None:
    """AsyncCallbackManager should also drive a synchronous handler correctly."""
    sync_handler = FakeCallbackHandler(always_verbose_=True)
    async_handler = FakeAsyncCallbackHandler()
    await _test_callback_manager_async(
        AsyncCallbackManager([sync_handler, async_handler]),
        sync_handler,
        async_handler,
    )