Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-21 13:10:35 +00:00)
Generator in Python Bindings - streaming yields tokens at a time (#895)
* generator method
* cleanup
* bump version number for clarity
* added replace in decode to avoid UnicodeDecodeError
* revert back to _build_prompt
parent bc7935e5f5
commit 4fb82afb00
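For orientation, a minimal usage sketch of the streaming interface this commit adds; the model filename below is a placeholder for whatever GPT4All model file is installed locally:

    # Hedged sketch: assumes the 0.3.2 bindings and an already-downloaded model file.
    from gpt4all import GPT4All

    model = GPT4All("ggml-gpt4all-j-v1.3-groovy")  # placeholder model filename

    # The new generator() surfaces tokens one at a time as the backend emits them.
    for token in model.generator("Name three colors of the rainbow."):
        print(token, end="", flush=True)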
@@ -122,7 +122,6 @@ class GPT4All():
         Returns:
             Model file destination.
         """
 
         def get_download_url(model_filename):
             if url:
                 return url
@@ -162,6 +161,8 @@ class GPT4All():
             print("Model downloaded at: ", download_path)
         return download_path
 
+    # TODO: this naming is just confusing now and needs to be deprecated now that we have generator
+    # Need to better consolidate all these different model response methods
     def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
         """
         Surfaced method of running generate without accessing model object.
@@ -174,7 +175,21 @@ class GPT4All():
         Returns:
             Raw string of generated model response.
         """
-        return self.model.generate(prompt, streaming=streaming, **generate_kwargs)
+        return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs)
 
+    def generator(self, prompt: str, **generate_kwargs) -> str:
+        """
+        Surfaced method of running generate without accessing model object.
+
+        Args:
+            prompt: Raw string to be passed to model.
+            streaming: True if want output streamed to stdout.
+            **generate_kwargs: Optional kwargs to pass to prompt context.
+
+        Returns:
+            Raw string of generated model response.
+        """
+        return self.model.generator(prompt, **generate_kwargs)
+
     def chat_completion(self,
                         messages: List[Dict],
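Since GPT4All.generator() forwards **generate_kwargs straight into the LLModel prompt context, sampling parameters from the LLModel signatures later in this diff (n_predict, temp, top_k, and so on) can be tuned per call. A hedged sketch, reusing the placeholder model filename from above:

    # Hedged sketch: generate_kwargs are passed through to the prompt context.
    from gpt4all import GPT4All

    model = GPT4All("ggml-gpt4all-j-v1.3-groovy")  # placeholder model filename
    for token in model.generator("Write a haiku about llamas.",
                                 n_predict=64, temp=0.7, top_k=40):
        print(token, end="", flush=True)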
@@ -209,14 +224,13 @@ class GPT4All():
             "choices": List of message dictionary where "content" is generated response and "role" is set
                 as "assistant". Right now, only one choice is returned by model.
         """
 
         full_prompt = self._build_prompt(messages,
                                          default_prompt_header=default_prompt_header,
                                          default_prompt_footer=default_prompt_footer)
         if verbose:
             print(full_prompt)
 
-        response = self.model.generate(full_prompt, streaming=streaming, **generate_kwargs)
+        response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs)
 
         if verbose and not streaming:
             print(response)
@@ -241,8 +255,23 @@ class GPT4All():
     @staticmethod
     def _build_prompt(messages: List[Dict],
                       default_prompt_header=True,
-                      default_prompt_footer=False) -> str:
-        # Helper method to format messages into prompt.
+                      default_prompt_footer=True) -> str:
+        """
+        Helper method for building a prompt using template from list of messages.
+
+        Args:
+            messages: List of dictionaries. Each dictionary should have a "role" key
+                with value of "system", "assistant", or "user" and a "content" key with a
+                string value. Messages are organized such that "system" messages are at top of prompt,
+                and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
+                "Response: {content}".
+            default_prompt_header: If True (default), add default prompt header after any system role messages and
+                before user/assistant role messages.
+            default_prompt_footer: If True (default), add default footer at end of prompt.
+
+        Returns:
+            Formatted prompt.
+        """
         full_prompt = ""
 
         for message in messages:
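To make the documented formatting concrete, here is a hand-worked sketch of how a message list is flattened under these rules; it is an illustration of the docstring, not the committed implementation, and it omits the optional default header and footer:

    # Illustration of the documented _build_prompt rules (header/footer omitted).
    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Hi."},
        {"role": "assistant", "content": "Hello."},
        {"role": "user", "content": "What is GPT4All?"},
    ]

    # System messages go first; user/assistant messages follow in order,
    # with assistant turns rendered as "Response: {content}".
    parts = [m["content"] for m in messages if m["role"] == "system"]
    for m in messages:
        if m["role"] == "user":
            parts.append(m["content"])
        elif m["role"] == "assistant":
            parts.append("Response: " + m["content"])
    full_prompt = "\n".join(parts)
    print(full_prompt)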
@@ -2,9 +2,11 @@ import pkg_resources
 import ctypes
 import os
 import platform
+import queue
 import re
 import subprocess
 import sys
+import threading
 
 class DualStreamProcessor:
     def __init__(self, stream=None):
@@ -167,21 +169,21 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
 
-    def generate(self,
+    def prompt_model(self,
                  prompt: str,
                  logits_size: int = 0,
                  tokens_size: int = 0,
                  n_past: int = 0,
                  n_ctx: int = 1024,
                  n_predict: int = 128,
                  top_k: int = 40,
                  top_p: float = .9,
                  temp: float = .1,
                  n_batch: int = 8,
                  repeat_penalty: float = 1.2,
                  repeat_last_n: int = 10,
                  context_erase: float = .5,
-                 streaming: bool = False) -> str:
+                 streaming: bool = True) -> str:
         """
         Generate response from model from a prompt.
 
@@ -237,6 +239,82 @@ class LLModel:
             print()
         return stream_processor.output
 
+    def generator(self,
+                  prompt: str,
+                  logits_size: int = 0,
+                  tokens_size: int = 0,
+                  n_past: int = 0,
+                  n_ctx: int = 1024,
+                  n_predict: int = 128,
+                  top_k: int = 40,
+                  top_p: float = .9,
+                  temp: float = .1,
+                  n_batch: int = 8,
+                  repeat_penalty: float = 1.2,
+                  repeat_last_n: int = 10,
+                  context_erase: float = .5) -> str:
+
+        # Symbol to terminate from generator
+        TERMINATING_SYMBOL = "#TERMINATE#"
+
+        output_queue = queue.Queue()
+
+        prompt = prompt.encode('utf-8')
+        prompt = ctypes.c_char_p(prompt)
+
+        context = LLModelPromptContext(
+            logits_size=logits_size,
+            tokens_size=tokens_size,
+            n_past=n_past,
+            n_ctx=n_ctx,
+            n_predict=n_predict,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temp,
+            n_batch=n_batch,
+            repeat_penalty=repeat_penalty,
+            repeat_last_n=repeat_last_n,
+            context_erase=context_erase
+        )
+
+        # Put response tokens into an output queue
+        def _generator_response_callback(token_id, response):
+            output_queue.put(response.decode('utf-8', 'replace'))
+            return True
+
+        def run_llmodel_prompt(model,
+                               prompt,
+                               prompt_callback,
+                               response_callback,
+                               recalculate_callback,
+                               context):
+            llmodel.llmodel_prompt(model,
+                                   prompt,
+                                   prompt_callback,
+                                   response_callback,
+                                   recalculate_callback,
+                                   context)
+            output_queue.put(TERMINATING_SYMBOL)
+
+
+        # Kick off llmodel_prompt in separate thread so we can return generator
+        # immediately
+        thread = threading.Thread(target=run_llmodel_prompt,
+                                  args=(self.model,
+                                        prompt,
+                                        PromptCallback(self._prompt_callback),
+                                        ResponseCallback(_generator_response_callback),
+                                        RecalculateCallback(self._recalculate_callback),
+                                        context))
+        thread.start()
+
+        # Generator
+        while True:
+            response = output_queue.get()
+            if response == TERMINATING_SYMBOL:
+                break
+            yield response
+
     # Empty prompt callback
     @staticmethod
     def _prompt_callback(token_id):
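The new LLModel.generator() follows a classic producer/consumer arrangement: llmodel_prompt runs in a worker thread whose response callback pushes each decoded token into a queue.Queue, while the Python generator drains the queue until the terminating sentinel appears. A self-contained sketch of that pattern, with a stand-in producer in place of the real C call:

    # Standalone sketch of the queue-plus-thread streaming pattern; produce_tokens
    # stands in for llmodel.llmodel_prompt.
    import queue
    import threading
    import time

    SENTINEL = object()  # unique end-of-stream marker (the commit uses the string "#TERMINATE#")

    def produce_tokens(emit):
        for tok in ["Hello", ", ", "world", "!"]:
            time.sleep(0.05)       # simulate generation latency
            emit(tok)              # like the response callback putting tokens on the queue

    def token_generator():
        out = queue.Queue()

        def worker():
            produce_tokens(out.put)
            out.put(SENTINEL)      # tell the consumer that generation finished

        threading.Thread(target=worker, daemon=True).start()

        while True:                # drain the queue until the sentinel arrives
            token = out.get()
            if token is SENTINEL:
                break
            yield token

    for t in token_generator():
        print(t, end="", flush=True)
    print()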
@@ -245,7 +323,7 @@ class LLModel:
     # Empty response callback method that just prints response to be collected
     @staticmethod
     def _response_callback(token_id, response):
-        sys.stdout.write(response.decode('utf-8'))
+        sys.stdout.write(response.decode('utf-8', 'replace'))
         return True
 
     # Empty recalculate callback
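The errors='replace' argument matters because streamed response bytes can end partway through a multi-byte UTF-8 character; a strict decode of such a fragment raises the UnicodeDecodeError the commit message refers to, while 'replace' substitutes U+FFFD instead. A quick illustration:

    # Why decode(..., 'replace') is safer for streamed bytes: a truncated UTF-8
    # sequence (the first 3 bytes of a 4-byte emoji) cannot be strictly decoded.
    fragment = b"\xf0\x9f\x98"

    try:
        fragment.decode("utf-8")
    except UnicodeDecodeError as exc:
        print("strict decode failed:", exc)

    print(fragment.decode("utf-8", "replace"))  # prints the U+FFFD replacement character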
@@ -61,7 +61,7 @@ copy_prebuilt_C_lib(SRC_CLIB_DIRECtORY,
 
 setup(
     name=package_name,
-    version="0.3.1",
+    version="0.3.2",
     description="Python bindings for GPT4All",
     author="Richard Guo",
     author_email="richard@nomic.ai",