Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-22 06:39:52 +00:00)
feat(llms): add streaming support to textgen (#9295)
- Description: Added streaming support to the textgen component in the llms module.
- Dependencies: websocket-client = "^1.6.1"
This commit is contained in:
parent a03003f5fd
commit bb4f7936f9
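In practice, the feature is exercised the way the notebook cells added in this diff do. The following is a condensed sketch of those cells; the ws://localhost:5005 URL is only an assumption for a locally running text-generation-webui instance with its websocket API enabled:

from langchain import PromptTemplate, LLMChain
from langchain.llms import TextGen
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

model_url = "ws://localhost:5005"  # assumed local text-generation-webui websocket endpoint

# Streamed chain run: tokens are printed by the callback as they arrive.
llm = TextGen(
    model_url=model_url,
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)
prompt = PromptTemplate(
    template="Question: {question}\n\nAnswer: Let's think step by step.",
    input_variables=["question"],
)
llm_chain = LLMChain(prompt=prompt, llm=llm)
llm_chain.run("What NFL team won the Super Bowl in the year Justin Bieber was born?")

# Direct streaming: iterate over chunks yielded from the websocket stream.
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'", stop=["'", "\n"]):
    print(chunk, end="", flush=True)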
@@ -26,7 +26,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {
     "tags": []
    },
@@ -61,6 +61,71 @@
     "\n",
     "llm_chain.run(question)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Streaming Version"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You should install websocket-client to use this feature.\n",
+    "`pip install websocket-client`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_url = \"ws://localhost:5005\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import langchain\n",
+    "from langchain import PromptTemplate, LLMChain\n",
+    "from langchain.llms import TextGen\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "langchain.debug = True\n",
+    "\n",
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "\n",
+    "\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+    "llm = TextGen(model_url=model_url, streaming=True, callbacks=[StreamingStdOutCallbackHandler()])\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+    "question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = TextGen(\n",
+    "    model_url = model_url,\n",
+    "    streaming=True\n",
+    ")\n",
+    "for chunk in llm.stream(\"Ask 'Hi, how are you?' like a pirate:'\",\n",
+    "                        stop=[\"'\",\"\\n\"]):\n",
+    "    print(chunk, end='', flush=True)"
+   ]
   }
  ],
  "metadata": {
@@ -79,7 +144,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.7"
+   "version": "3.10.4"
   }
  },
  "nbformat": 4,
@@ -1,11 +1,13 @@
+import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional

 import requests

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Field
+from langchain.schema.output import GenerationChunk

 logger = logging.getLogger(__name__)

@@ -109,7 +111,7 @@ class TextGen(LLM):
     """A list of strings to stop generation when encountered."""

     streaming: bool = False
-    """Whether to stream the results, token by token (currently unimplemented)."""
+    """Whether to stream the results, token by token."""

     @property
     def _default_params(self) -> Dict[str, Any]:
@@ -198,19 +200,99 @@
                 llm("Write a story about llamas.")
         """
         if self.streaming:
-            raise ValueError("`streaming` option currently unsupported.")
+            combined_text_output = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
+            print(prompt + combined_text_output)
+            result = combined_text_output

-        url = f"{self.model_url}/api/v1/generate"
-        params = self._get_parameters(stop)
-        request = params.copy()
-        request["prompt"] = prompt
-        response = requests.post(url, json=request)
-
-        if response.status_code == 200:
-            result = response.json()["results"][0]["text"]
-            print(prompt + result)
         else:
-            print(f"ERROR: response: {response}")
-            result = ""
+            url = f"{self.model_url}/api/v1/generate"
+            params = self._get_parameters(stop)
+            request = params.copy()
+            request["prompt"] = prompt
+            response = requests.post(url, json=request)
+
+            if response.status_code == 200:
+                result = response.json()["results"][0]["text"]
+                print(prompt + result)
+            else:
+                print(f"ERROR: response: {response}")
+                result = ""

         return result
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Yields results objects as they are generated in real time.
+
+        It also calls the callback manager's on_llm_new_token event with
+        similar parameters to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            A dictionary like objects containing a string token and metadata.
+            See text-generation-webui docs and below for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(
+                    model_url = "ws://localhost:5005"
+                    streaming=True
+                )
+                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+                                        stop=["'","\n"]):
+                    print(chunk, end='', flush=True)
+
+        """
+        try:
+            import websocket
+        except ImportError:
+            raise ImportError(
+                "The `websocket-client` package is required for streaming."
+            )
+
+        params = {**self._get_parameters(stop), **kwargs}
+
+        url = f"{self.model_url}/api/v1/stream"
+
+        request = params.copy()
+        request["prompt"] = prompt
+
+        websocket_client = websocket.WebSocket()
+
+        websocket_client.connect(url)
+
+        websocket_client.send(json.dumps(request))
+
+        while True:
+            result = websocket_client.recv()
+            result = json.loads(result)
+
+            if result["event"] == "text_stream":
+                chunk = GenerationChunk(
+                    text=result["text"],
+                    generation_info=None,
+                )
+                yield chunk
+            elif result["event"] == "stream_end":
+                websocket_client.close()
+                return
+
+            if run_manager:
+                run_manager.on_llm_new_token(token=chunk.text)
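For reference, the websocket exchange that the new _stream method performs against text-generation-webui's streaming endpoint looks roughly like the standalone sketch below. This is not part of the commit; the server URL and the max_new_tokens value are assumptions for a local setup:

# Minimal sketch of the raw text-generation-webui streaming protocol used by _stream,
# assuming a local server exposing ws://localhost:5005/api/v1/stream and the
# websocket-client package (pip install websocket-client).
import json
import websocket

ws = websocket.WebSocket()
ws.connect("ws://localhost:5005/api/v1/stream")
ws.send(json.dumps({"prompt": "Write a story about llamas.", "max_new_tokens": 250}))

while True:
    message = json.loads(ws.recv())
    if message["event"] == "text_stream":
        # Each message carries the next chunk of generated text.
        print(message["text"], end="", flush=True)
    elif message["event"] == "stream_end":
        ws.close()
        break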