mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 20:16:52 +00:00
Enable streaming for OpenAI LLM (#986)
* Support a callback `on_llm_new_token` that users can implement when `OpenAI.streaming` is set to `True`
This commit is contained in:
@@ -5,9 +5,11 @@ from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.llms.loading import load_llm
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.schema import LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_openai_call() -> None:
|
||||
@@ -77,9 +79,66 @@ def test_openai_streaming_error() -> None:
|
||||
llm.stream("I'm Pickle Rick")
|
||||
|
||||
|
||||
def test_openai_streaming_best_of_error() -> None:
    """Constructing a streaming OpenAI LLM with ``best_of=2`` must raise."""
    # Keyword arguments are collected first so the invalid combination
    # (streaming together with best_of != 1) is visible at a glance.
    invalid_kwargs = {"best_of": 2, "streaming": True}
    with pytest.raises(ValueError):
        OpenAI(**invalid_kwargs)
|
||||
|
||||
|
||||
def test_openai_streaming_n_error() -> None:
    """Constructing a streaming OpenAI LLM with ``n=2`` must raise."""
    # Streaming combined with n != 1 is the invalid configuration under test.
    invalid_kwargs = {"n": 2, "streaming": True}
    with pytest.raises(ValueError):
        OpenAI(**invalid_kwargs)
|
||||
|
||||
|
||||
def test_openai_streaming_multiple_prompts_error() -> None:
    """Expect a ValueError when a streaming LLM is given two prompts."""
    duplicated_prompts = ["I'm Pickle Rick"] * 2
    with pytest.raises(ValueError):
        # Construction stays inside the context manager so that a ValueError
        # from either the constructor or ``generate`` satisfies the check.
        streaming_llm = OpenAI(streaming=True)
        streaming_llm.generate(duplicated_prompts)
|
||||
|
||||
|
||||
def test_openai_streaming_call() -> None:
    """Calling a streaming OpenAI LLM directly should yield a ``str``."""
    streaming_llm = OpenAI(streaming=True, max_tokens=10)
    completion = streaming_llm("Say foo:")
    assert isinstance(completion, str)
|
||||
|
||||
|
||||
def test_openai_streaming_callback() -> None:
    """Check that a streaming call drives the ``on_llm_new_token`` callback.

    The fake handler's ``llm_streams`` counter is expected to equal 10 after
    the call, the same value passed as ``max_tokens``.
    """
    handler = FakeCallbackHandler()
    streaming_llm = OpenAI(
        callback_manager=CallbackManager([handler]),
        verbose=True,
        temperature=0,
        streaming=True,
        max_tokens=10,
    )
    # Only the callback side effect matters here; the returned text is unused.
    streaming_llm("Write me a sentence with 100 words.")
    assert handler.llm_streams == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
    """Awaiting ``agenerate`` on an OpenAI LLM should return an ``LLMResult``."""
    prompts = ["Hello, how are you?"]
    result = await OpenAI(max_tokens=10).agenerate(prompts)
    assert isinstance(result, LLMResult)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_async_streaming_callback() -> None:
    """Check that async streaming drives the ``on_llm_new_token`` callback.

    After ``agenerate`` completes, the fake handler's ``llm_streams`` counter
    is expected to equal 10 (the ``max_tokens`` value) and the awaited value
    is expected to be an ``LLMResult``.
    """
    handler = FakeCallbackHandler()
    streaming_llm = OpenAI(
        callback_manager=CallbackManager([handler]),
        verbose=True,
        temperature=0,
        streaming=True,
        max_tokens=10,
    )
    prompts = ["Write me a sentence with 100 words."]
    llm_result = await streaming_llm.agenerate(prompts)
    # Keep the original assertion order: stream count first, then result type.
    assert handler.llm_streams == 10
    assert isinstance(llm_result, LLMResult)
|
||||
|
Reference in New Issue
Block a user