Support Streaming Tokens from OpenAI (#364)

https://github.com/hwchase17/langchain/issues/363

@hwchase17 how much does this make you want to cry?
This commit is contained in:
mrbean 2022-12-17 10:02:58 -05:00 committed by GitHub
parent fe6695b9e7
commit 50257fce59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 1 deletion

View File

@@ -1,6 +1,6 @@
"""Wrapper around OpenAI APIs.""" """Wrapper around OpenAI APIs."""
import sys import sys
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, Generator, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
@@ -160,6 +160,30 @@ class OpenAI(LLM, BaseModel):
generations=generations, llm_output={"token_usage": token_usage} generations=generations, llm_output={"token_usage": token_usage}
) )
def stream(self, prompt: str) -> Generator:
    """Call OpenAI with the streaming flag and return the resulting generator.

    Args:
        prompt: The prompt to pass into the model.

    Returns:
        A generator representing the stream of tokens from OpenAI.

    Raises:
        ValueError: If ``best_of`` is not 1 — OpenAI cannot stream when
            multiple completion candidates are requested.

    Example:
        .. code-block:: python

            generator = openai.stream("Tell me a joke.")
            for token in generator:
                yield token
    """
    # Build a fresh dict instead of mutating the one returned by the
    # _default_params property in place: if that dict is shared or cached,
    # setting "stream" on it would leak into later non-streaming calls.
    params = {**self._default_params, "stream": True}
    if params["best_of"] != 1:
        raise ValueError("OpenAI only supports best_of == 1 for streaming")
    return self.client.create(model=self.model_name, prompt=prompt, **params)
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""

View File

@@ -1,6 +1,7 @@
"""Test OpenAI API wrapper.""" """Test OpenAI API wrapper."""
from pathlib import Path from pathlib import Path
from typing import Generator
import pytest import pytest
@@ -55,3 +56,21 @@ def test_saving_loading_llm(tmp_path: Path) -> None:
llm.save(file_path=tmp_path / "openai.yaml") llm.save(file_path=tmp_path / "openai.yaml")
loaded_llm = load_llm(tmp_path / "openai.yaml") loaded_llm = load_llm(tmp_path / "openai.yaml")
assert loaded_llm == llm assert loaded_llm == llm
def test_openai_streaming() -> None:
    """Streaming should yield a Generator whose chunks carry text tokens."""
    model = OpenAI(max_tokens=10)
    token_stream = model.stream("I'm Pickle Rick")
    assert isinstance(token_stream, Generator)
    for chunk in token_stream:
        text = chunk["choices"][0]["text"]
        assert isinstance(text, str)
def test_openai_streaming_error() -> None:
    """Streaming must be rejected with ValueError when best_of != 1."""
    model = OpenAI(best_of=2)
    with pytest.raises(ValueError):
        model.stream("I'm Pickle Rick")