add async and streaming support to OpenAIChat (#1378)

Title says it all: `OpenAIChat` gains an `_agenerate` implementation and a `streaming` option, the base `LLM` class picks up a default async path via a new `_acall` hook, the OpenAI retry helpers move to module-level functions shared by `BaseOpenAI` and `OpenAIChat`, and the async/streaming notebooks and integration tests are updated to match.
This commit is contained in:
Ankush Gola 2023-03-01 21:55:43 -08:00 committed by GitHub
parent cfed0497ac
commit fe30be6fba
5 changed files with 321 additions and 140 deletions

View File

@@ -1,7 +1,6 @@
 {
  "cells": [
   {
-   "attachments": {},
    "cell_type": "markdown",
    "id": "f6574496-b360-4ffa-9523-7fd34a590164",
    "metadata": {},
@@ -10,14 +9,14 @@
    "\n",
    "LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n",
    "\n",
-   "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI` and `PromptLayerOpenAI` is supported, but async support for other LLMs is on the roadmap.\n",
+   "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI`, `OpenAIChat`, and `PromptLayerOpenAI` are supported, but async support for other LLMs is on the roadmap.\n",
    "\n",
    "You can use the `agenerate` method to call an OpenAI LLM asynchronously."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": 1,
   "id": "5e49e96c-0f88-466d-b3d3-ea0966bdf19e",
   "metadata": {
    "tags": []
@@ -29,64 +28,66 @@
    "text": [
     "\n",
     "\n",
-    "I'm doing well. How about you?\n",
+    "As an AI language model, I don't have feelings like humans, but I'm functioning properly. How may I assist you?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "I'm an AI language model, so I don't have emotions, but I'm functioning properly. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I do not have emotions like humans, but I'm functioning normally. How can I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "I am an AI language model, so I do not have feelings, but I am here to assist you. How may I help you today?\n",
-    "\n",
-    "I am doing quite well. How about you?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I do not have feelings or emotions but I'm always ready to assist you. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing great, thank you! How about you?\n",
+    "As an AI language model, I don't have feelings, but I'm functioning normally. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thanks for asking. How about you?\n",
+    "As an AI language model, I don't have feelings, but I'm functioning properly. Thank you. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I don't have emotions, so I don't have a specific feeling or emotion. How can I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I do not have feelings or emotions. However, I am functioning as intended and ready to assist you with any queries you may have. How can I be of assistance today?\n",
-    "\u001b[1mConcurrent executed in 1.93 seconds.\u001b[0m\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I do not have feelings, but I am functioning well. Thank you for asking. How can I assist you today?\n",
+    "\u001b[1mConcurrent executed in 0.92 seconds.\u001b[0m\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I don't have feelings, but I'm functioning well. How can I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I don't have feelings, but I'm functioning well. Thank you for asking. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "I'm an AI language model, so I don't have feelings, but I'm functioning well. How can I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I don't have feelings, but I'm functioning well. Thank you for asking. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I don't have feelings, but I am functioning well. How can I assist you today?\n",
-    "\n",
-    "I'm doing well, thank you. How about you?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I don't have feelings but I'm functioning well. How can I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing well, thank you. How about you?\n",
+    "As an AI language model, I do not have personal emotions. However, I am functioning well and ready to assist you with any queries or tasks you have. How may I assist you today?\n",
     "\n",
     "\n",
-    "I'm doing great, thank you. How about you?\n",
+    "As an AI language model, I do not have feelings or emotions, but I'm functioning well. How can I assist you today?\n",
-    "\u001b[1mSerial executed in 10.54 seconds.\u001b[0m\n"
+    "\n",
+    "\n",
+    "I am an AI language model and do not have feelings. But I am functioning properly and ready to assist you with any task. How may I help you today?\n",
+    "\n",
+    "\n",
+    "As an AI language model, I do not have emotions, but I am functioning well. How can I assist you today?\n",
+    "\u001b[1mSerial executed in 5.00 seconds.\u001b[0m\n"
    ]
   }
  ],
@@ -94,10 +95,10 @@
    "import time\n",
    "import asyncio\n",
    "\n",
-   "from langchain.llms import OpenAI\n",
+   "from langchain.llms import OpenAIChat\n",
    "\n",
    "def generate_serially():\n",
-   "    llm = OpenAI(temperature=0.9)\n",
+   "    llm = OpenAIChat(temperature=0.9)\n",
    "    for _ in range(10):\n",
    "        resp = llm.generate([\"Hello, how are you?\"])\n",
    "        print(resp.generations[0][0].text)\n",
@@ -109,7 +110,7 @@
    "\n",
    "\n",
    "async def generate_concurrently():\n",
-   "    llm = OpenAI(temperature=0.9)\n",
+   "    llm = OpenAIChat(temperature=0.9)\n",
    "    tasks = [async_generate(llm) for _ in range(10)]\n",
    "    await asyncio.gather(*tasks)\n",
    "\n",
@@ -125,6 +126,14 @@
    "elapsed = time.perf_counter() - s\n",
    "print('\\033[1m' + f\"Serial executed in {elapsed:0.2f} seconds.\" + '\\033[0m')"
   ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "e1d3a966-3a27-44e8-9441-ed72f01b86f4",
+  "metadata": {},
+  "outputs": [],
+  "source": []
  }
 ],
 "metadata": {

View File

@@ -7,12 +7,12 @@
   "source": [
    "# Streaming with LLMs\n",
    "\n",
-   "LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()."
+   "LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` and `OpenAIChat` LLM implementations, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 9,
+  "execution_count": 2,
   "id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
   "metadata": {
    "tags": []
@@ -27,43 +27,43 @@
    "Verse 1\n",
    "I'm sippin' on sparkling water,\n",
    "It's so refreshing and light,\n",
-   "It's the perfect way to quench my thirst,\n",
+   "It's the perfect way to quench my thirst\n",
    "On a hot summer night.\n",
    "\n",
    "Chorus\n",
    "Sparkling water, sparkling water,\n",
    "It's the best way to stay hydrated,\n",
-   "It's so refreshing and light,\n",
-   "It's the perfect way to stay alive.\n",
+   "It's so crisp and so clean,\n",
+   "It's the perfect way to stay refreshed.\n",
    "\n",
    "Verse 2\n",
    "I'm sippin' on sparkling water,\n",
    "It's so bubbly and bright,\n",
-   "It's the perfect way to cool me down,\n",
+   "It's the perfect way to cool me down\n",
    "On a hot summer night.\n",
    "\n",
    "Chorus\n",
    "Sparkling water, sparkling water,\n",
    "It's the best way to stay hydrated,\n",
-   "It's so refreshing and light,\n",
-   "It's the perfect way to stay alive.\n",
+   "It's so crisp and so clean,\n",
+   "It's the perfect way to stay refreshed.\n",
    "\n",
    "Verse 3\n",
    "I'm sippin' on sparkling water,\n",
-   "It's so crisp and clean,\n",
-   "It's the perfect way to keep me going,\n",
-   "On a hot summer day.\n",
+   "It's so light and so clear,\n",
+   "It's the perfect way to keep me cool\n",
+   "On a hot summer night.\n",
    "\n",
    "Chorus\n",
    "Sparkling water, sparkling water,\n",
    "It's the best way to stay hydrated,\n",
-   "It's so refreshing and light,\n",
-   "It's the perfect way to stay alive."
+   "It's so crisp and so clean,\n",
+   "It's the perfect way to stay refreshed."
   ]
  }
 ],
 "source": [
-  "from langchain.llms import OpenAI\n",
+  "from langchain.llms import OpenAI, OpenAIChat\n",
   "from langchain.callbacks.base import CallbackManager\n",
   "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
   "\n",
@@ -84,7 +84,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": 3,
   "id": "a35373f1-9ee6-4753-a343-5aee749b8527",
   "metadata": {
    "tags": []
@@ -103,10 +103,10 @@
   {
    "data": {
     "text/plain": [
-     "LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})"
+     "LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': None, 'logprobs': None})]], llm_output={'token_usage': {}})"
     ]
    },
-   "execution_count": 8,
+   "execution_count": 3,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -114,6 +114,85 @@
  "source": [
   "llm.generate([\"Tell me a joke.\"])"
  ]
},
{
"cell_type": "markdown",
"id": "a93a4d61-0476-49db-8321-7de92bd74059",
"metadata": {},
"source": [
"Here's an example with `OpenAIChat`:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "22665f16-e05b-473c-a4bd-ad75744ea024",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Verse 1:\n",
"Bubbles rising to the top\n",
"A refreshing drink that never stops\n",
"Clear and crisp, it's pure delight\n",
"A taste that's sure to excite\n",
"\n",
"Chorus:\n",
"Sparkling water, oh so fine\n",
"A drink that's always on my mind\n",
"With every sip, I feel alive\n",
"Sparkling water, you're my vibe\n",
"\n",
"Verse 2:\n",
"No sugar, no calories, just pure bliss\n",
"A drink that's hard to resist\n",
"It's the perfect way to quench my thirst\n",
"A drink that always comes first\n",
"\n",
"Chorus:\n",
"Sparkling water, oh so fine\n",
"A drink that's always on my mind\n",
"With every sip, I feel alive\n",
"Sparkling water, you're my vibe\n",
"\n",
"Bridge:\n",
"From the mountains to the sea\n",
"Sparkling water, you're the key\n",
"To a healthy life, a happy soul\n",
"A drink that makes me feel whole\n",
"\n",
"Chorus:\n",
"Sparkling water, oh so fine\n",
"A drink that's always on my mind\n",
"With every sip, I feel alive\n",
"Sparkling water, you're my vibe\n",
"\n",
"Outro:\n",
"Sparkling water, you're the one\n",
"A drink that's always so much fun\n",
"I'll never let you go, my friend\n",
"Sparkling water, until the end."
]
}
],
"source": [
"llm = OpenAIChat(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"resp = llm(\"Write me a song about sparkling water.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc",
"metadata": {},
"outputs": [],
"source": []
  }
 ],
 "metadata": {

View File

@@ -319,6 +319,10 @@ class LLM(BaseLLM):
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Run the LLM on the given prompt and input."""
 
+    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Run the LLM on the given prompt and input."""
+        raise NotImplementedError("Async generation not implemented for this LLM.")
+
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
@@ -334,4 +338,8 @@
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        raise NotImplementedError("Async generation not implemented for this LLM.")
+        generations = []
+        for prompt in prompts:
+            text = await self._acall(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
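To illustrate what the new default `_agenerate` buys downstream implementers, here is a minimal sketch of a custom LLM (hypothetical `EchoLLM`) that only fills in `_call` and `_acall` and gets `agenerate` for free. It assumes the `LLM` interface shown in this hunk plus the usual `_llm_type` property; it is not part of the commit.

from typing import List, Optional

from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Hypothetical toy LLM: echoes the prompt back, synchronously or asynchronously."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Synchronous path used by generate() / __call__.
        return prompt

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Async path: with this defined, the default _agenerate above wraps each
        # returned string in a Generation and builds the LLMResult.
        return prompt


# result = await EchoLLM().agenerate(["hello"])  # LLMResult with one Generation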

View File

@@ -1,4 +1,6 @@
 """Wrapper around OpenAI APIs."""
+from __future__ import annotations
+
 import logging
 import sys
 from typing import (
@@ -63,6 +65,53 @@
     }
def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]:
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return llm.client.create(**kwargs)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.acreate(**kwargs)
return await _completion_with_retry(**kwargs)
 class BaseOpenAI(BaseLLM, BaseModel):
     """Wrapper around OpenAI large language models.
@@ -174,48 +223,6 @@ class BaseOpenAI(BaseLLM, BaseModel):
         }
         return {**normal_params, **self.model_kwargs}
 
-    def _create_retry_decorator(self) -> Callable[[Any], Any]:
-        import openai
-
-        min_seconds = 4
-        max_seconds = 10
-        # Wait 2^x * 1 second between each retry starting with
-        # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-        return retry(
-            reraise=True,
-            stop=stop_after_attempt(self.max_retries),
-            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-            retry=(
-                retry_if_exception_type(openai.error.Timeout)
-                | retry_if_exception_type(openai.error.APIError)
-                | retry_if_exception_type(openai.error.APIConnectionError)
-                | retry_if_exception_type(openai.error.RateLimitError)
-                | retry_if_exception_type(openai.error.ServiceUnavailableError)
-            ),
-            before_sleep=before_sleep_log(logger, logging.WARNING),
-        )
-
-    def completion_with_retry(self, **kwargs: Any) -> Any:
-        """Use tenacity to retry the completion call."""
-        retry_decorator = self._create_retry_decorator()
-
-        @retry_decorator
-        def _completion_with_retry(**kwargs: Any) -> Any:
-            return self.client.create(**kwargs)
-
-        return _completion_with_retry(**kwargs)
-
-    async def acompletion_with_retry(self, **kwargs: Any) -> Any:
-        """Use tenacity to retry the async completion call."""
-        retry_decorator = self._create_retry_decorator()
-
-        @retry_decorator
-        async def _completion_with_retry(**kwargs: Any) -> Any:
-            # Use OpenAI's async api https://github.com/openai/openai-python#async-api
-            return await self.client.acreate(**kwargs)
-
-        return await _completion_with_retry(**kwargs)
-
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
@@ -247,8 +254,8 @@
                 raise ValueError("Cannot stream results with multiple prompts.")
             params["stream"] = True
             response = _streaming_response_template()
-            for stream_resp in self.completion_with_retry(
-                prompt=_prompts, **params
+            for stream_resp in completion_with_retry(
+                self, prompt=_prompts, **params
             ):
                 self.callback_manager.on_llm_new_token(
                     stream_resp["choices"][0]["text"],
@@ -258,7 +265,7 @@
                 _update_response(response, stream_resp)
             choices.extend(response["choices"])
         else:
-            response = self.completion_with_retry(prompt=_prompts, **params)
+            response = completion_with_retry(self, prompt=_prompts, **params)
             choices.extend(response["choices"])
         if not self.streaming:
             # Can't update token usage if streaming
@@ -282,8 +289,8 @@
                 raise ValueError("Cannot stream results with multiple prompts.")
             params["stream"] = True
             response = _streaming_response_template()
-            async for stream_resp in await self.acompletion_with_retry(
-                prompt=_prompts, **params
+            async for stream_resp in await acompletion_with_retry(
+                self, prompt=_prompts, **params
             ):
                 if self.callback_manager.is_async:
                     await self.callback_manager.on_llm_new_token(
@@ -300,7 +307,7 @@
                 _update_response(response, stream_resp)
             choices.extend(response["choices"])
         else:
-            response = await self.acompletion_with_retry(prompt=_prompts, **params)
+            response = await acompletion_with_retry(self, prompt=_prompts, **params)
             choices.extend(response["choices"])
         if not self.streaming:
             # Can't update token usage if streaming
@@ -540,6 +547,9 @@ class OpenAIChat(BaseLLM, BaseModel):
     max_retries: int = 6
     """Maximum number of retries to make when generating."""
     prefix_messages: List = Field(default_factory=list)
+    """Series of messages for Chat input."""
+    streaming: bool = False
+    """Whether to stream the results or not."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -590,61 +600,82 @@
         """Get the default parameters for calling OpenAI API."""
         return self.model_kwargs
 
-    def _create_retry_decorator(self) -> Callable[[Any], Any]:
-        import openai
-
-        min_seconds = 4
-        max_seconds = 10
-        # Wait 2^x * 1 second between each retry starting with
-        # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
-        return retry(
-            reraise=True,
-            stop=stop_after_attempt(self.max_retries),
-            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-            retry=(
-                retry_if_exception_type(openai.error.Timeout)
-                | retry_if_exception_type(openai.error.APIError)
-                | retry_if_exception_type(openai.error.APIConnectionError)
-                | retry_if_exception_type(openai.error.RateLimitError)
-                | retry_if_exception_type(openai.error.ServiceUnavailableError)
-            ),
-            before_sleep=before_sleep_log(logger, logging.WARNING),
-        )
-
-    def completion_with_retry(self, **kwargs: Any) -> Any:
-        """Use tenacity to retry the completion call."""
-        retry_decorator = self._create_retry_decorator()
-
-        @retry_decorator
-        def _completion_with_retry(**kwargs: Any) -> Any:
-            return self.client.create(**kwargs)
-
-        return _completion_with_retry(**kwargs)
-
-    def _generate(
+    def _get_chat_params(
         self, prompts: List[str], stop: Optional[List[str]] = None
-    ) -> LLMResult:
+    ) -> Tuple:
         if len(prompts) > 1:
-            raise ValueError(f"OpenAIChat only supports single prompts, got {prompts}")
+            raise ValueError(
+                f"OpenAIChat currently only supports single prompt, got {prompts}"
+            )
         messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
         params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
-        response = self.completion_with_retry(messages=messages, **params)
+        return messages, params
 
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
if self.streaming:
response = ""
params["stream"] = True
for stream_resp in completion_with_retry(self, messages=messages, **params):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
else:
full_response = completion_with_retry(self, messages=messages, **params)
             return LLMResult(
                 generations=[
-                    [Generation(text=response["choices"][0]["message"]["content"])]
+                    [Generation(text=full_response["choices"][0]["message"]["content"])]
                 ],
-                llm_output={"token_usage": response["usage"]},
+                llm_output={"token_usage": full_response["usage"]},
             )
 
     async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        raise NotImplementedError("Async generation not implemented for this LLM.")
+        messages, params = self._get_chat_params(prompts, stop)
+        if self.streaming:
response = ""
params["stream"] = True
async for stream_resp in await acompletion_with_retry(
self, messages=messages, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
else:
full_response = await acompletion_with_retry(
self, messages=messages, **params
)
return LLMResult(
generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])]
],
llm_output={"token_usage": full_response["usage"]},
)
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
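Taken together, the new `streaming` flag and `_agenerate` make the following end-to-end sketch possible (the prompts are illustrative, an `OPENAI_API_KEY` is assumed in the environment, and `asyncio.run` stands in for a notebook's top-level `await`):

import asyncio

from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import OpenAIChat

llm = OpenAIChat(
    streaming=True,  # stream deltas through the callback manager as they arrive
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    temperature=0,
)

# Sync streaming: tokens are printed by the handler while the call runs.
llm("Write me a song about sparkling water.")

# Async streaming: same callbacks, driven by acompletion_with_retry under the hood.
result = asyncio.run(llm.agenerate(["Write me a haiku about sparkling water."]))
print(result.generations[0][0].text)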

View File

@@ -7,7 +7,7 @@ import pytest
 from langchain.callbacks.base import CallbackManager
 from langchain.llms.loading import load_llm
-from langchain.llms.openai import OpenAI
+from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -142,3 +142,57 @@
     result = await llm.agenerate(["Write me a sentence with 100 words."])
     assert callback_handler.llm_streams == 10
     assert isinstance(result, LLMResult)
def test_openai_chat() -> None:
"""Test OpenAIChat."""
llm = OpenAIChat(max_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)
def test_openai_chat_streaming() -> None:
"""Test OpenAIChat with streaming option."""
llm = OpenAIChat(max_tokens=10, streaming=True)
output = llm("Say foo:")
assert isinstance(output, str)
def test_openai_chat_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = OpenAIChat(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
llm("Write me a sentence with 100 words.")
assert callback_handler.llm_streams != 0
@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
"""Test async chat."""
llm = OpenAIChat(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)
@pytest.mark.asyncio
async def test_openai_chat_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = OpenAIChat(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
result = await llm.agenerate(["Write me a sentence with 100 words."])
assert callback_handler.llm_streams != 0
assert isinstance(result, LLMResult)
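These are live integration tests against the OpenAI API. Assuming pytest and pytest-asyncio are installed and an OpenAI API key is configured in the environment, something like `pytest -k openai_chat` from the repository root should select just the new `OpenAIChat` tests; the `@pytest.mark.asyncio` cases above are the ones that need the pytest-asyncio plugin.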