mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 06:53:59 +00:00
Enable streaming for OpenAI LLM (#986)
* Support a callback `on_llm_new_token` that users can implement when `OpenAI.streaming` is set to `True`
This commit is contained in:
@@ -3,12 +3,12 @@ from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
"""Base fake callback handler for testing."""
|
||||
|
||||
starts: int = 0
|
||||
ends: int = 0
|
||||
@@ -44,10 +44,15 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
llm_streams: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_ends: int = 0
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
@@ -55,6 +60,10 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
@@ -110,3 +119,74 @@ class FakeCallbackHandler(BaseModel, BaseCallbackHandler):
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
"""Fake async callback handler for testing."""
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.llm_streams += 1
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.text += 1
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
@@ -1,13 +1,24 @@
|
||||
"""Test CallbackManager."""
|
||||
from typing import Tuple
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
FakeCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
@@ -20,6 +31,27 @@ def _test_callback_manager(
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
_check_num_calls(handlers)
|
||||
|
||||
|
||||
async def _test_callback_manager_async(
|
||||
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
await manager.on_llm_start({}, [])
|
||||
await manager.on_llm_end(LLMResult(generations=[]))
|
||||
await manager.on_llm_error(Exception())
|
||||
await manager.on_chain_start({"name": "foo"}, {})
|
||||
await manager.on_chain_end({})
|
||||
await manager.on_chain_error(Exception())
|
||||
await manager.on_tool_start({}, AgentAction("", "", ""))
|
||||
await manager.on_tool_end("")
|
||||
await manager.on_tool_error(Exception())
|
||||
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
_check_num_calls(handlers)
|
||||
|
||||
|
||||
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
|
||||
for handler in handlers:
|
||||
if handler.always_verbose:
|
||||
assert handler.starts == 3
|
||||
@@ -128,3 +160,21 @@ def test_shared_callback_manager() -> None:
|
||||
manager1.add_handler(handler1)
|
||||
manager2.add_handler(handler2)
|
||||
_test_callback_manager(manager1, handler1, handler2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeAsyncCallbackHandler()
|
||||
manager = AsyncCallbackManager([handler1, handler2])
|
||||
await _test_callback_manager_async(manager, handler1, handler2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager_sync_handler() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeAsyncCallbackHandler()
|
||||
manager = AsyncCallbackManager([handler1, handler2])
|
||||
await _test_callback_manager_async(manager, handler1, handler2)
|
||||
|
Reference in New Issue
Block a user