TextGen is missing async methods. (#9986)

Adds the `_acall` and `_astream` methods that were missing; without them,
streaming was not possible during async executions.

cc @rlancemartin.
German Martin committed 2023-09-03 18:57:40 -03:00 (committed by GitHub)
parent f4bed8a04c
commit cf5a50469f
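
For orientation, here is a minimal sketch (not part of the diff) of how the new async paths are exercised, assuming a running text-generation-webui server at the URLs shown and a langchain release whose base LLM class exposes apredict and astream:

import asyncio

from langchain.llms import TextGen


async def main() -> None:
    # Non-streaming: apredict routes through the new _acall.
    llm = TextGen(model_url="http://localhost:5000")
    print(await llm.apredict("Write a story about llamas."))

    # Streaming: astream routes through the new _astream.
    streaming_llm = TextGen(model_url="ws://localhost:5005", streaming=True)
    async for token in streaming_llm.astream("Say hi like a pirate:"):
        print(token, end="", flush=True)


asyncio.run(main())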


@@ -1,10 +1,13 @@
 import json
 import logging
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
 
 import requests
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Field
 from langchain.schema.output import GenerationChunk
@@ -224,6 +227,54 @@ class TextGen(LLM):
         return result
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the textgen web API and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(model_url="http://localhost:5000")
+                llm("Write a story about llamas.")
+        """
+        if self.streaming:
+            # Aggregate the streamed chunks into a single completion.
+            combined_text_output = ""
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
+            print(prompt + combined_text_output)
+            result = combined_text_output
+        else:
+            # Hit the blocking (non-streaming) generate endpoint.
+            url = f"{self.model_url}/api/v1/generate"
+            params = self._get_parameters(stop)
+            request = params.copy()
+            request["prompt"] = prompt
+            response = requests.post(url, json=request)
+
+            if response.status_code == 200:
+                result = response.json()["results"][0]["text"]
+                print(prompt + result)
+            else:
+                print(f"ERROR: response: {response}")
+                result = ""
+
+        return result
+
     def _stream(
         self,
         prompt: str,
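
For reference, the non-streaming branch of _acall above performs the same exchange as this standalone sketch against text-generation-webui's legacy blocking API; the endpoint and response shape are taken from the code, while max_new_tokens is an illustrative parameter:

import requests

# POST the prompt plus generation parameters, then read the first
# result's text, mirroring _acall's non-streaming branch.
response = requests.post(
    "http://localhost:5000/api/v1/generate",
    json={"prompt": "Write a story about llamas.", "max_new_tokens": 64},
)
if response.status_code == 200:
    print(response.json()["results"][0]["text"])
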
@@ -296,3 +347,76 @@ class TextGen(LLM):
             if run_manager:
                 run_manager.on_llm_new_token(token=chunk.text)
 
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        """Yield result objects as they are generated in real time.
+
+        It also calls the callback manager's on_llm_new_token event with
+        similar parameters to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            Dictionary-like objects each containing a string token and
+            metadata. See text-generation-webui docs and below for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import TextGen
+                llm = TextGen(
+                    model_url="ws://localhost:5005",
+                    streaming=True,
+                )
+                async for chunk in llm.astream(
+                    "Ask 'Hi, how are you?' like a pirate:'",
+                    stop=["'", "\n"],
+                ):
+                    print(chunk, end="", flush=True)
+        """
+        try:
+            import websocket
+        except ImportError:
+            raise ImportError(
+                "The `websocket-client` package is required for streaming."
+            )
+
+        params = {**self._get_parameters(stop), **kwargs}
+
+        url = f"{self.model_url}/api/v1/stream"
+
+        request = params.copy()
+        request["prompt"] = prompt
+
+        websocket_client = websocket.WebSocket()
+
+        websocket_client.connect(url)
+
+        websocket_client.send(json.dumps(request))
+
+        while True:
+            result = websocket_client.recv()
+            result = json.loads(result)
+
+            if result["event"] == "text_stream":
+                chunk = GenerationChunk(
+                    text=result["text"],
+                    generation_info=None,
+                )
+                yield chunk
+                # Notify the callback manager for each streamed token; kept
+                # inside this branch so `chunk` is always bound when used.
+                if run_manager:
+                    await run_manager.on_llm_new_token(token=chunk.text)
+            elif result["event"] == "stream_end":
+                websocket_client.close()
+                return
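
The websocket protocol that _astream consumes can also be exercised directly; a minimal sketch (not part of the diff), assuming a server exposing the legacy /api/v1/stream endpoint and the event names used in the code above:

import json

import websocket  # provided by the websocket-client package

ws = websocket.WebSocket()
ws.connect("ws://localhost:5005/api/v1/stream")
ws.send(json.dumps({"prompt": "Say hi like a pirate:"}))

while True:
    event = json.loads(ws.recv())
    if event["event"] == "text_stream":
        # Each message carries one newly generated text fragment.
        print(event["text"], end="", flush=True)
    elif event["event"] == "stream_end":
        ws.close()
        break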