From bb4f7936f91a1b6d26c4c73cc4e1a6612d732b8f Mon Sep 17 00:00:00 2001
From: Utku Ege Tuluk <36725861+uetuluk@users.noreply.github.com>
Date: Mon, 21 Aug 2023 22:39:14 +0800
Subject: [PATCH] feat(llms): add streaming support to textgen (#9295)

- Description: Added streaming support to the textgen component in the llms module.
- Dependencies: websocket-client = "^1.6.1"
---
 docs/extras/integrations/llms/textgen.ipynb |  69 +++++++++++-
 libs/langchain/langchain/llms/textgen.py    | 110 +++++++++++++++++---
 2 files changed, 163 insertions(+), 16 deletions(-)

diff --git a/docs/extras/integrations/llms/textgen.ipynb b/docs/extras/integrations/llms/textgen.ipynb
index 490e3a4b370..3ffd83e69f9 100644
--- a/docs/extras/integrations/llms/textgen.ipynb
+++ b/docs/extras/integrations/llms/textgen.ipynb
@@ -26,7 +26,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {
     "tags": []
    },
@@ -61,6 +61,71 @@
     "\n",
     "llm_chain.run(question)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Streaming Version"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You will need to install websocket-client to use this feature.\n",
+    "`pip install websocket-client`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_url = \"ws://localhost:5005\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import langchain\n",
+    "from langchain import PromptTemplate, LLMChain\n",
+    "from langchain.llms import TextGen\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "langchain.debug = True\n",
+    "\n",
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "\n",
+    "\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+    "llm = TextGen(model_url=model_url, streaming=True, callbacks=[StreamingStdOutCallbackHandler()])\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+    "question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = TextGen(\n",
+    "    model_url=model_url,\n",
+    "    streaming=True\n",
+    ")\n",
+    "for chunk in llm.stream(\"Ask 'Hi, how are you?' like a pirate:'\",\n",
+    "                        stop=[\"'\",\"\\n\"]):\n",
+    "    print(chunk, end='', flush=True)"
+   ]
   }
  ],
  "metadata": {
@@ -79,7 +144,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.7"
+   "version": "3.10.4"
   }
  },
  "nbformat": 4,
diff --git a/libs/langchain/langchain/llms/textgen.py b/libs/langchain/langchain/llms/textgen.py
index bcee0eb83ed..0d846ecc882 100644
--- a/libs/langchain/langchain/llms/textgen.py
+++ b/libs/langchain/langchain/llms/textgen.py
@@ -1,11 +1,13 @@
+import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional

 import requests

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Field
+from langchain.schema.output import GenerationChunk

 logger = logging.getLogger(__name__)

@@ -109,7 +111,7 @@ class TextGen(LLM):
     """A list of strings to stop generation when encountered."""

     streaming: bool = False
-    """Whether to stream the results, token by token (currently unimplemented)."""
+    """Whether to stream the results, token by token."""

     @property
     def _default_params(self) -> Dict[str, Any]:
@@ -198,19 +200,99 @@ class TextGen(LLM):
                 llm("Write a story about llamas.")
         """
         if self.streaming:
-            raise ValueError("`streaming` option currently unsupported.")
+            combined_text_output = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
+            print(prompt + combined_text_output)
+            result = combined_text_output

-        url = f"{self.model_url}/api/v1/generate"
-        params = self._get_parameters(stop)
-        request = params.copy()
-        request["prompt"] = prompt
-        response = requests.post(url, json=request)
-
-        if response.status_code == 200:
-            result = response.json()["results"][0]["text"]
-            print(prompt + result)
         else:
-            print(f"ERROR: response: {response}")
-            result = ""
+            url = f"{self.model_url}/api/v1/generate"
+            params = self._get_parameters(stop)
+            request = params.copy()
+            request["prompt"] = prompt
+            response = requests.post(url, json=request)
+
+            if response.status_code == 200:
+                result = response.json()["results"][0]["text"]
+                print(prompt + result)
+            else:
+                print(f"ERROR: response: {response}")
+                result = ""

         return result
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Yields result objects as they are generated in real time.
+
+        It also calls the callback manager's on_llm_new_token event with
+        parameters similar to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            GenerationChunk objects containing a string token and generation
+            metadata. See the text-generation-webui docs and the example below
+            for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(
+                    model_url="ws://localhost:5005",
+                    streaming=True
+                )
+                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+                                        stop=["'","\n"]):
+                    print(chunk, end='', flush=True)
+
+        """
+        try:
+            import websocket
+        except ImportError:
+            raise ImportError(
+                "The `websocket-client` package is required for streaming."
+            )
+
+        params = {**self._get_parameters(stop), **kwargs}
+
+        url = f"{self.model_url}/api/v1/stream"
+
+        request = params.copy()
+        request["prompt"] = prompt
+
+        websocket_client = websocket.WebSocket()
+
+        websocket_client.connect(url)
+
+        websocket_client.send(json.dumps(request))
+
+        while True:
+            result = websocket_client.recv()
+            result = json.loads(result)
+
+            if result["event"] == "text_stream":
+                chunk = GenerationChunk(
+                    text=result["text"],
+                    generation_info=None,
+                )
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(token=chunk.text)
+            elif result["event"] == "stream_end":
+                websocket_client.close()
+                return
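For reviewers who want to exercise the endpoint directly, here is a minimal standalone sketch of the text-generation-webui websocket protocol that `_stream` consumes. It relies only on the `websocket-client` dependency named above; the `/api/v1/stream` path and the `text_stream` / `stream_end` event names mirror the patch, while the server URL and the `stream_tokens` helper are illustrative assumptions rather than part of the library.

.. code-block:: python

    import json

    import websocket  # provided by the websocket-client package

    # Assumed local text-generation-webui server, as in the notebook example above.
    STREAM_URL = "ws://localhost:5005/api/v1/stream"


    def stream_tokens(prompt: str) -> None:
        """Print tokens as the server emits them, mirroring TextGen._stream."""
        ws = websocket.WebSocket()
        ws.connect(STREAM_URL)
        # The request carries the prompt plus any generation parameters;
        # only the prompt is sent here, so server defaults apply.
        ws.send(json.dumps({"prompt": prompt}))
        try:
            while True:
                event = json.loads(ws.recv())
                if event["event"] == "text_stream":
                    print(event["text"], end="", flush=True)
                elif event["event"] == "stream_end":
                    break
        finally:
            ws.close()


    stream_tokens("Ask 'Hi, how are you?' like a pirate:")

Because the request/response shape is the same one `_stream` sends and parses, this also doubles as a quick connectivity check for the `model_url` used in the notebook.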