feat(llms): add streaming support to textgen (#9295)

- Description: Added streaming support to the textgen component in the llms module.
- Dependencies: websocket-client = "^1.6.1"
Author: Utku Ege Tuluk, 2023-08-21 22:39:14 +08:00 (committed by GitHub)
commit bb4f7936f9 (parent a03003f5fd)
2 changed files with 163 additions and 16 deletions
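
The new `_stream` method added in the second file drives text-generation-webui's websocket streaming API: it sends a single JSON request containing the prompt and generation parameters, then reads `text_stream` events until a `stream_end` event arrives. Below is a minimal sketch of that exchange, independent of LangChain, assuming a local text-generation-webui server with its API extension enabled and the streaming endpoint on the default port 5005:

import json

import websocket  # pip install websocket-client

# Endpoint used by TextGen._stream; ws://localhost:5005 is an assumed local server.
ws = websocket.WebSocket()
ws.connect("ws://localhost:5005/api/v1/stream")
ws.send(json.dumps({"prompt": "Write a story about llamas."}))

while True:
    event = json.loads(ws.recv())
    if event["event"] == "text_stream":
        # Each event carries the next chunk of generated text.
        print(event["text"], end="", flush=True)
    elif event["event"] == "stream_end":
        ws.close()
        break

The `TextGen` class below wraps this same loop, adding stop sequences, request parameters, and LangChain's callback plumbing.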


@@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"tags": []
},
@@ -61,6 +61,71 @@
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming Version"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should install websocket-client to use this feature.\n",
"`pip install websocket-client`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_url = \"ws://localhost:5005\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import langchain\n",
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.llms import TextGen\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"\n",
"langchain.debug = True\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"llm = TextGen(model_url=model_url, streaming=True, callbacks=[StreamingStdOutCallbackHandler()])\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = TextGen(\n",
" model_url = model_url,\n",
" streaming=True\n",
")\n",
"for chunk in llm.stream(\"Ask 'Hi, how are you?' like a pirate:'\",\n",
" stop=[\"'\",\"\\n\"]):\n",
" print(chunk, end='', flush=True)"
]
}
],
"metadata": {
@@ -79,7 +144,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
"version": "3.10.4"
}
},
"nbformat": 4,


@@ -1,11 +1,13 @@
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterator, List, Optional
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field
from langchain.schema.output import GenerationChunk
logger = logging.getLogger(__name__)
@@ -109,7 +111,7 @@ class TextGen(LLM):
"""A list of strings to stop generation when encountered."""
streaming: bool = False
"""Whether to stream the results, token by token (currently unimplemented)."""
"""Whether to stream the results, token by token."""
@property
def _default_params(self) -> Dict[str, Any]:
@@ -198,8 +200,15 @@ class TextGen(LLM):
llm("Write a story about llamas.")
"""
if self.streaming:
raise ValueError("`streaming` option currently unsupported.")
            combined_text_output = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output
else:
url = f"{self.model_url}/api/v1/generate"
params = self._get_parameters(stop)
request = params.copy()
@@ -214,3 +223,76 @@ class TextGen(LLM):
result = ""
return result
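    # Note: when `streaming=True`, `_call` simply aggregates the chunks yielded
    # by `_stream` below into a single string, so callers that expect one blocking
    # response still work; token-by-token delivery happens via the callbacks.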
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Yields results objects as they are generated in real time.
It also calls the callback manager's on_llm_new_token event with
similar parameters to the OpenAI LLM class method of the same name.
Args:
prompt: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens being generated.
Yields:
A dictionary like objects containing a string token and metadata.
See text-generation-webui docs and below for more.
Example:
.. code-block:: python
from langchain.llms import TextGen
                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
stop=["'","\n"]):
print(chunk, end='', flush=True)
"""
try:
import websocket
except ImportError:
raise ImportError(
"The `websocket-client` package is required for streaming."
)
params = {**self._get_parameters(stop), **kwargs}
url = f"{self.model_url}/api/v1/stream"
request = params.copy()
request["prompt"] = prompt
websocket_client = websocket.WebSocket()
websocket_client.connect(url)
websocket_client.send(json.dumps(request))
        # Read events from the server until it signals the end of the stream.
        while True:
            result = websocket_client.recv()
            result = json.loads(result)

            if result["event"] == "text_stream":
                chunk = GenerationChunk(
                    text=result["text"],
                    generation_info=None,
                )
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(token=chunk.text)
            elif result["event"] == "stream_end":
                websocket_client.close()
                return
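
Because `_stream` reports each chunk through `run_manager.on_llm_new_token`, any LangChain callback handler observes tokens as they arrive, even when the model is invoked synchronously through `_call`. A usage sketch against the same assumed local server; the `TokenCounter` handler is hypothetical and written here only for illustration:

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import TextGen


class TokenCounter(BaseCallbackHandler):
    """Hypothetical handler that counts chunks streamed by TextGen."""

    def __init__(self) -> None:
        self.tokens = 0

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.tokens += 1


counter = TokenCounter()
llm = TextGen(model_url="ws://localhost:5005", streaming=True, callbacks=[counter])
llm("Write a story about llamas.")
print(f"\nreceived {counter.tokens} streamed chunks")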