Generator in Python Bindings - streaming yields tokens one at a time (#895)

* generator method

* cleanup

* bump version number for clarity

* added 'replace' error handling in decode to avoid UnicodeDecodeError

* revert back to _build_prompt
Richard Guo 2023-06-09 10:17:44 -04:00 committed by GitHub
parent bc7935e5f5
commit 4fb82afb00
3 changed files with 130 additions and 23 deletions
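Before the per-file diffs, here is a minimal consumer-side sketch of what the new generator enables: iterating over response tokens as they are produced instead of waiting for the full string. The model filename and constructor arguments are illustrative assumptions, not part of this commit.

```python
# Hypothetical usage of the new streaming generator surfaced by the bindings.
# The model name is a placeholder; any locally available GPT4All model file works.
from gpt4all import GPT4All

gpt = GPT4All("ggml-gpt4all-j-v1.3-groovy")
for token in gpt.generator("Write one sentence about the ocean."):
    print(token, end="", flush=True)
print()
```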


@@ -122,7 +122,6 @@ class GPT4All():
         Returns:
             Model file destination.
         """
         def get_download_url(model_filename):
             if url:
                 return url
@@ -162,6 +161,8 @@ class GPT4All():
             print("Model downloaded at: ", download_path)
         return download_path

+    # TODO: this naming is just confusing now and needs to be deprecated now that we have generator
+    # Need to better consolidate all these different model response methods
     def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
         """
         Surfaced method of running generate without accessing model object.
@@ -174,7 +175,21 @@ class GPT4All():
         Returns:
             Raw string of generated model response.
         """
-        return self.model.generate(prompt, streaming=streaming, **generate_kwargs)
+        return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs)
+
+    def generator(self, prompt: str, **generate_kwargs) -> str:
+        """
+        Surfaced method of running a streaming generator without accessing model object.
+
+        Args:
+            prompt: Raw string to be passed to model.
+            **generate_kwargs: Optional kwargs to pass to prompt context.
+
+        Returns:
+            Generator yielding model response tokens as they become available.
+        """
+        return self.model.generator(prompt, **generate_kwargs)

     def chat_completion(self,
                         messages: List[Dict],
@@ -209,14 +224,13 @@ class GPT4All():
             "choices": List of message dictionary where "content" is generated response and "role" is set
                 as "assistant". Right now, only one choice is returned by model.
         """
         full_prompt = self._build_prompt(messages,
                                          default_prompt_header=default_prompt_header,
                                          default_prompt_footer=default_prompt_footer)
         if verbose:
             print(full_prompt)

-        response = self.model.generate(full_prompt, streaming=streaming, **generate_kwargs)
+        response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs)

         if verbose and not streaming:
             print(response)
@@ -241,8 +255,23 @@ class GPT4All():
     @staticmethod
     def _build_prompt(messages: List[Dict],
                       default_prompt_header=True,
-                      default_prompt_footer=False) -> str:
+                      default_prompt_footer=True) -> str:
-        # Helper method to format messages into prompt.
+        """
+        Helper method for building a prompt from a list of messages.
+
+        Args:
+            messages: List of dictionaries. Each dictionary should have a "role" key
+                with value of "system", "assistant", or "user" and a "content" key with a
+                string value. Messages are organized such that "system" messages are at top of prompt,
+                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
+                "Response: {content}".
+            default_prompt_header: If True (default), add default prompt header after any system role messages and
+                before user/assistant role messages.
+            default_prompt_footer: If True (default), add default footer at end of prompt.
+
+        Returns:
+            Formatted prompt.
+        """
         full_prompt = ""
         for message in messages:
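The new _build_prompt docstring fully specifies the ordering rules, so a standalone sketch of that behavior may help. The name build_prompt_sketch and the header/footer strings below are placeholders of mine, not the library's actual defaults (which this diff does not show).

```python
# A sketch of the prompt assembly described by the _build_prompt docstring above.
from typing import Dict, List

DEFAULT_HEADER = "### Instruction: answer the user's question."  # placeholder text
DEFAULT_FOOTER = "### Response:"                                 # placeholder text


def build_prompt_sketch(messages: List[Dict],
                        default_prompt_header: bool = True,
                        default_prompt_footer: bool = True) -> str:
    full_prompt = ""
    # System messages go at the top of the prompt.
    for message in messages:
        if message["role"] == "system":
            full_prompt += message["content"] + "\n"
    # The optional header sits between system messages and the conversation turns.
    if default_prompt_header:
        full_prompt += DEFAULT_HEADER + "\n"
    # User and assistant turns are appended in order; assistant turns get the
    # "Response: {content}" formatting the docstring describes.
    for message in messages:
        if message["role"] == "user":
            full_prompt += message["content"] + "\n"
        elif message["role"] == "assistant":
            full_prompt += "Response: " + message["content"] + "\n"
    # The optional footer closes the prompt.
    if default_prompt_footer:
        full_prompt += DEFAULT_FOOTER + "\n"
    return full_prompt
```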


@@ -2,9 +2,11 @@ import pkg_resources
 import ctypes
 import os
 import platform
+import queue
 import re
 import subprocess
 import sys
+import threading

 class DualStreamProcessor:
     def __init__(self, stream=None):
@@ -167,21 +169,21 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)

-    def generate(self,
+    def prompt_model(self,
                  prompt: str,
                  logits_size: int = 0,
                  tokens_size: int = 0,
                  n_past: int = 0,
                  n_ctx: int = 1024,
                  n_predict: int = 128,
                  top_k: int = 40,
                  top_p: float = .9,
                  temp: float = .1,
                  n_batch: int = 8,
                  repeat_penalty: float = 1.2,
                  repeat_last_n: int = 10,
                  context_erase: float = .5,
-                 streaming: bool = False) -> str:
+                 streaming: bool = True) -> str:
         """
         Generate response from model from a prompt.
@@ -237,6 +239,82 @@ class LLModel:
             print()

         return stream_processor.output
+
+    def generator(self,
+                  prompt: str,
+                  logits_size: int = 0,
+                  tokens_size: int = 0,
+                  n_past: int = 0,
+                  n_ctx: int = 1024,
+                  n_predict: int = 128,
+                  top_k: int = 40,
+                  top_p: float = .9,
+                  temp: float = .1,
+                  n_batch: int = 8,
+                  repeat_penalty: float = 1.2,
+                  repeat_last_n: int = 10,
+                  context_erase: float = .5) -> str:
+
+        # Symbol to terminate from generator
+        TERMINATING_SYMBOL = "#TERMINATE#"
+
+        output_queue = queue.Queue()
+
+        prompt = prompt.encode('utf-8')
+        prompt = ctypes.c_char_p(prompt)
+
+        context = LLModelPromptContext(
+            logits_size=logits_size,
+            tokens_size=tokens_size,
+            n_past=n_past,
+            n_ctx=n_ctx,
+            n_predict=n_predict,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temp,
+            n_batch=n_batch,
+            repeat_penalty=repeat_penalty,
+            repeat_last_n=repeat_last_n,
+            context_erase=context_erase
+        )
+
+        # Put response tokens into an output queue
+        def _generator_response_callback(token_id, response):
+            output_queue.put(response.decode('utf-8', 'replace'))
+            return True
+
+        def run_llmodel_prompt(model,
+                               prompt,
+                               prompt_callback,
+                               response_callback,
+                               recalculate_callback,
+                               context):
+            llmodel.llmodel_prompt(model,
+                                   prompt,
+                                   prompt_callback,
+                                   response_callback,
+                                   recalculate_callback,
+                                   context)
+            output_queue.put(TERMINATING_SYMBOL)
+
+        # Kick off llmodel_prompt in separate thread so we can return generator
+        # immediately
+        thread = threading.Thread(target=run_llmodel_prompt,
+                                  args=(self.model,
+                                        prompt,
+                                        PromptCallback(self._prompt_callback),
+                                        ResponseCallback(_generator_response_callback),
+                                        RecalculateCallback(self._recalculate_callback),
+                                        context))
+        thread.start()
+
+        # Generator
+        while True:
+            response = output_queue.get()
+            if response == TERMINATING_SYMBOL:
+                break
+            yield response

     # Empty prompt callback
     @staticmethod
     def _prompt_callback(token_id):
@@ -245,7 +323,7 @@ class LLModel:

     # Empty response callback method that just prints response to be collected
     @staticmethod
     def _response_callback(token_id, response):
-        sys.stdout.write(response.decode('utf-8'))
+        sys.stdout.write(response.decode('utf-8', 'replace'))
         return True

     # Empty recalculate callback
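The generator added above rests on a small producer/consumer pattern: a worker thread pushes decoded tokens onto a queue from the C response callback, and the Python generator drains that queue until a terminating sentinel arrives. Below is a self-contained sketch of the same pattern with a fake backend standing in for llmodel.llmodel_prompt; every name in it is illustrative.

```python
# A self-contained sketch of the producer/consumer pattern used by the new
# LLModel.generator: a worker thread feeds tokens into a queue and a sentinel
# value tells the consuming generator to stop. fake_model_prompt stands in for
# the real llmodel.llmodel_prompt call.
import queue
import threading
import time
from typing import Iterator

_SENTINEL = object()  # cannot collide with any real token, unlike a magic string


def fake_model_prompt(emit_token) -> None:
    """Pretend C backend that invokes a response callback once per token."""
    for token in ["Hello", ",", " world", "!"]:
        time.sleep(0.05)  # simulate generation latency
        emit_token(token)


def token_generator() -> Iterator[str]:
    output_queue: queue.Queue = queue.Queue()

    def worker() -> None:
        fake_model_prompt(output_queue.put)  # producer: push each token
        output_queue.put(_SENTINEL)          # signal that generation finished

    threading.Thread(target=worker, daemon=True).start()

    while True:                              # consumer: drain until sentinel
        item = output_queue.get()
        if item is _SENTINEL:
            break
        yield item


if __name__ == "__main__":
    for tok in token_generator():
        print(tok, end="", flush=True)
    print()
```

Using an object() sentinel rather than a magic string like "#TERMINATE#" is one way to guarantee the terminator can never collide with a token the model actually emits.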


@@ -61,7 +61,7 @@ copy_prebuilt_C_lib(SRC_CLIB_DIRECtORY,

 setup(
     name=package_name,
-    version="0.3.1",
+    version="0.3.2",
     description="Python bindings for GPT4All",
     author="Richard Guo",
     author_email="richard@nomic.ai",
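The only change in this file is the version bump. Assuming the bindings are published to PyPI under the usual gpt4all name, a quick post-upgrade sanity check that the new API is present could look like this:

```python
# Hedged sanity check after running `pip install --upgrade gpt4all`:
# __version__ may not be exposed by every release, hence the getattr fallback.
import gpt4all

print(getattr(gpt4all, "__version__", "version attribute not exposed"))
print(hasattr(gpt4all.GPT4All, "generator"))  # True once the generator method ships
```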