Generator in Python Bindings - streaming yields tokens one at a time (#895)

* generator method

* cleanup

* bump version number for clarity

* added errors='replace' in decode to avoid UnicodeDecodeError exceptions

* revert to _build_prompt
Richard Guo 2023-06-09 10:17:44 -04:00 committed by GitHub
parent bc7935e5f5
commit 4fb82afb00
3 changed files with 130 additions and 23 deletions


@@ -122,7 +122,6 @@ class GPT4All():
Returns:
Model file destination.
"""
def get_download_url(model_filename):
if url:
return url
@@ -162,6 +161,8 @@ class GPT4All():
print("Model downloaded at: ", download_path)
return download_path
# TODO: this naming is confusing and should be deprecated now that we have generator
# Need to better consolidate all these different model response methods
def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
"""
Surfaced method for running generate without accessing the model object directly.
@@ -174,7 +175,21 @@ class GPT4All():
Returns:
Raw string of generated model response.
"""
return self.model.generate(prompt, streaming=streaming, **generate_kwargs)
return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs)
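For context, a minimal usage sketch of this surfaced method; the model filename below is illustrative, not prescribed by this commit:

from gpt4all import GPT4All

model = GPT4All("ggml-gpt4all-j-v1.3-groovy")  # hypothetical model file
# With streaming=True, tokens are printed to stdout as they arrive,
# and the full response string is still returned at the end.
response = model.generate("Name three colors.", streaming=True)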
def generator(self, prompt: str, **generate_kwargs):
"""
Surfaced method for streaming generation without accessing the model object directly.
Args:
prompt: Raw string to be passed to the model.
**generate_kwargs: Optional kwargs to pass to the prompt context.
Returns:
Generator that yields response tokens one at a time as the model produces them.
"""
return self.model.generator(prompt, **generate_kwargs)
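A minimal sketch of consuming the new generator API, assuming the GPT4All instance from the sketch above; tokens are yielded one at a time as the model produces them:

for token in model.generator("Write a haiku about the sea."):
    print(token, end="", flush=True)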
def chat_completion(self,
messages: List[Dict],
@@ -209,14 +224,13 @@ class GPT4All():
"choices": List of message dictionary where "content" is generated response and "role" is set
as "assistant". Right now, only one choice is returned by model.
"""
full_prompt = self._build_prompt(messages,
default_prompt_header=default_prompt_header,
default_prompt_footer=default_prompt_footer)
if verbose:
print(full_prompt)
response = self.model.generate(full_prompt, streaming=streaming, **generate_kwargs)
response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs)
if verbose and not streaming:
print(response)
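For reference, a hedged sketch of calling chat_completion with the message format the docstring describes; the exact shape of the returned dictionary may vary by version:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Name three colors."},
]
result = model.chat_completion(messages, verbose=False, streaming=False)
# Per the docstring, "choices" holds message dictionaries whose
# "content" field is the generated response.
print(result["choices"])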
@@ -241,8 +255,23 @@ class GPT4All():
@staticmethod
def _build_prompt(messages: List[Dict],
default_prompt_header=True,
default_prompt_footer=False) -> str:
# Helper method to format messages into prompt.
default_prompt_footer=True) -> str:
"""
Helper method for building a prompt from a list of messages using the template.
Args:
messages: List of dictionaries. Each dictionary should have a "role" key
with value of "system", "assistant", or "user" and a "content" key with a
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Response: {content}".
default_prompt_header: If True (default), add default prompt header after any system role messages and
before user/assistant role messages.
default_prompt_footer: If True (default), add default footer at end of prompt.
Returns:
Formatted prompt.
"""
full_prompt = ""
for message in messages:
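A simplified sketch of the prompt-building rules the docstring above describes; the header and footer strings here are placeholders, not the library's actual defaults:

def build_prompt_sketch(messages, header=True, footer=True):
    prompt = ""
    # "system" messages go at the top of the prompt.
    for m in messages:
        if m["role"] == "system":
            prompt += m["content"] + "\n"
    if header:
        prompt += "### Instruction:\n"  # placeholder header text
    # "user" and "assistant" messages follow in order; assistant
    # messages are formatted as "Response: {content}".
    for m in messages:
        if m["role"] == "user":
            prompt += m["content"] + "\n"
        elif m["role"] == "assistant":
            prompt += "Response: " + m["content"] + "\n"
    if footer:
        prompt += "### Response:\n"  # placeholder footer text
    return prompt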


@@ -2,9 +2,11 @@ import pkg_resources
import ctypes
import os
import platform
import queue
import re
import subprocess
import sys
import threading
class DualStreamProcessor:
def __init__(self, stream=None):
@@ -167,21 +169,21 @@ class LLModel:
raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model)
def generate(self,
prompt: str,
logits_size: int = 0,
tokens_size: int = 0,
n_past: int = 0,
n_ctx: int = 1024,
n_predict: int = 128,
top_k: int = 40,
top_p: float = .9,
temp: float = .1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = .5,
streaming: bool = False) -> str:
def prompt_model(self,
prompt: str,
logits_size: int = 0,
tokens_size: int = 0,
n_past: int = 0,
n_ctx: int = 1024,
n_predict: int = 128,
top_k: int = 40,
top_p: float = .9,
temp: float = .1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = .5,
streaming: bool = True) -> str:
"""
Generate a response from the model for a given prompt.
@@ -237,6 +239,82 @@ class LLModel:
print()
return stream_processor.output
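The DualStreamProcessor used above tees output: it writes each chunk to stdout while accumulating the full text in .output. A hedged sketch of that pattern (TeeStream is a hypothetical stand-in, not the library's class):

import sys

class TeeStream:
    def __init__(self, stream=sys.stdout):
        self.stream = stream
        self.output = ""

    def write(self, text):
        # Accumulate the full response while echoing it immediately.
        self.output += text
        self.stream.write(text)
        self.stream.flush()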
def generator(self,
prompt: str,
logits_size: int = 0,
tokens_size: int = 0,
n_past: int = 0,
n_ctx: int = 1024,
n_predict: int = 128,
top_k: int = 40,
top_p: float = .9,
temp: float = .1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = .5):
# Sentinel symbol that signals the generator to terminate
TERMINATING_SYMBOL = "#TERMINATE#"
output_queue = queue.Queue()
prompt = prompt.encode('utf-8')
prompt = ctypes.c_char_p(prompt)
context = LLModelPromptContext(
logits_size=logits_size,
tokens_size=tokens_size,
n_past=n_past,
n_ctx=n_ctx,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase
)
# Put response tokens into an output queue
def _generator_response_callback(token_id, response):
output_queue.put(response.decode('utf-8', 'replace'))
return True
def run_llmodel_prompt(model,
prompt,
prompt_callback,
response_callback,
recalculate_callback,
context):
llmodel.llmodel_prompt(model,
prompt,
prompt_callback,
response_callback,
recalculate_callback,
context)
output_queue.put(TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator
# immediately
thread = threading.Thread(target=run_llmodel_prompt,
args=(self.model,
prompt,
PromptCallback(self._prompt_callback),
ResponseCallback(_generator_response_callback),
RecalculateCallback(self._recalculate_callback),
context))
thread.start()
# Yield tokens from the queue until the sentinel arrives
while True:
response = output_queue.get()
if response == TERMINATING_SYMBOL:
break
yield response
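Distilled, the code above is a standard producer/consumer bridge: a worker thread pushes tokens into a queue and the generator drains the queue until a sentinel arrives. A self-contained sketch of the technique:

import queue
import threading

def token_stream(produce_tokens):
    # Yield items produced by a worker thread until it finishes.
    sentinel = object()  # a unique object cannot collide with real output
    q = queue.Queue()

    def worker():
        for token in produce_tokens():
            q.put(token)
        q.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is sentinel:
            break
        yield item

# Example: prints "a", "b", "c" one at a time.
for t in token_stream(lambda: iter("abc")):
    print(t)

Using a unique object() as the sentinel sidesteps the (unlikely) case where the model emits the literal string "#TERMINATE#".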
# Empty prompt callback
@staticmethod
def _prompt_callback(token_id):
@@ -245,7 +323,7 @@ class LLModel:
# Default response callback that prints the response so it can be collected
@staticmethod
def _response_callback(token_id, response):
sys.stdout.write(response.decode('utf-8'))
sys.stdout.write(response.decode('utf-8', 'replace'))
return True
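The errors='replace' argument matters because a multi-byte UTF-8 character can be split across token boundaries; strict decoding then raises UnicodeDecodeError, while 'replace' substitutes the replacement character U+FFFD. A small demonstration:

snowman = "\u2603".encode("utf-8")  # three bytes: b'\xe2\x98\x83'
truncated = snowman[:2]             # an incomplete UTF-8 sequence

# truncated.decode("utf-8")         # would raise UnicodeDecodeError
print(truncated.decode("utf-8", "replace"))  # prints U+FFFD instead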
# Empty recalculate callback


@@ -61,7 +61,7 @@ copy_prebuilt_C_lib(SRC_CLIB_DIRECtORY,
setup(
name=package_name,
version="0.3.1",
version="0.3.2",
description="Python bindings for GPT4All",
author="Richard Guo",
author_email="richard@nomic.ai",