Python Bindings: Improved unit tests, documentation and unification of API (#1090)

* Makefiles, black, isort

* Black and isort

* unit tests and generation method

* chat context provider

* context does not reset

* Current state

* Fixup

* Python bindings with unit tests

* GPT4All Python Bindings: chat contexts, tests

* New python bindings and backend fixes

* Black and Isort

* Documentation error

* preserved n_predict for backwards compatibility with LangChain

---------

Co-authored-by: Adam Treat <treat.adam@gmail.com>
Author: Andriy Mulyar
Date: 2023-06-30 16:02:02 -04:00
Committed by: GitHub
Parent: 40a3faeb05
Commit: 46a0762bd5
15 changed files with 437 additions and 407 deletions

View File

@@ -1,2 +1,2 @@
from .pyllmodel import LLModel # noqa
from .gpt4all import GPT4All # noqa
from .gpt4all import GPT4All # noqa
from .pyllmodel import LLModel # noqa
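
For downstream users nothing changes here beyond import ordering: the package root still exports the same two names. A minimal sketch:

# Both public symbols remain importable from the package root after the isort pass.
from gpt4all import GPT4All, LLModel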

View File

@@ -2,9 +2,10 @@
Python only API for running all GPT4All models.
"""
import os
from pathlib import Path
import time
from typing import Dict, List
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterable, List, Union
import requests
from tqdm import tqdm
@@ -15,14 +16,14 @@ from . import pyllmodel
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
class GPT4All():
"""Python API for retrieving and interacting with GPT4All models.
Attributes:
model: Pointer to underlying C model.
class GPT4All:
"""
Python class that handles instantiation, downloading, generation and chat with GPT4All models.
"""
def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download = True, n_threads = None):
def __init__(
self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True, n_threads=None
):
"""
Constructor
@@ -32,7 +33,7 @@ class GPT4All():
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
model_type: Model architecture. This argument currently does not have any functionality and is just used as
descriptive identifier for user. Default is None.
allow_download: Allow API to download models from gpt4all.io. Default is True.
allow_download: Allow API to download models from gpt4all.io. Default is True.
n_threads: number of CPU threads used by GPT4All. Default is None, in which case the number of threads is determined automatically.
"""
self.model_type = model_type
@@ -41,11 +42,14 @@ class GPT4All():
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
self.model.load_model(model_dest)
# Set n_threads
if n_threads != None:
if n_threads is not None:
self.model.set_thread_count(n_threads)
self._is_chat_session_activated = False
self.current_chat_session = []
@staticmethod
def list_models():
def list_models() -> Dict:
"""
Fetch model list from https://gpt4all.io/models/models.json.
@@ -55,8 +59,9 @@ class GPT4All():
return requests.get("https://gpt4all.io/models/models.json").json()
@staticmethod
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True,
verbose: bool = True) -> str:
def retrieve_model(
model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True
) -> str:
"""
Find model file, and if it doesn't exist, download the model.
@@ -78,8 +83,10 @@ class GPT4All():
try:
os.makedirs(DEFAULT_MODEL_DIRECTORY, exist_ok=True)
except OSError as exc:
raise ValueError(f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. "
"Please specify model_path.")
raise ValueError(
f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. "
"Please specify model_path."
)
model_path = DEFAULT_MODEL_DIRECTORY
else:
model_path = model_path.replace("\\", "\\\\")
@@ -108,7 +115,7 @@ class GPT4All():
raise ValueError(f"Model filename not in model list: {model_filename}")
url = selected_model.pop('url', None)
return GPT4All.download_model(model_filename, model_path, verbose = verbose, url=url)
return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
else:
raise ValueError("Failed to retrieve model")
@@ -126,6 +133,7 @@ class GPT4All():
Returns:
Model file destination.
"""
def get_download_url(model_filename):
if url:
return url
@@ -137,7 +145,7 @@ class GPT4All():
response = requests.get(download_url, stream=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 2 ** 20 # 1 MB
block_size = 2**20 # 1 MB
with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
try:
@@ -154,9 +162,7 @@ class GPT4All():
# Validate download was successful
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise RuntimeError(
"An error occurred during download. Downloaded file may not work."
)
raise RuntimeError("An error occurred during download. Downloaded file may not work.")
# Sleep for a little bit so OS can remove file lock
time.sleep(2)
@@ -165,101 +171,83 @@ class GPT4All():
print("Model downloaded at: ", download_path)
return download_path
# TODO: this naming is just confusing now and needs to be deprecated now that we have generator
# Need to better consolidate all these different model response methods
def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
def generate(
self,
prompt: str,
max_tokens: int = 200,
temp: float = 0.7,
top_k: int = 40,
top_p: float = 0.1,
repeat_penalty: float = 1.18,
repeat_last_n: int = 64,
n_batch: int = 8,
n_predict: int = None,
streaming: bool = False,
) -> Union[str, Iterable]:
"""
Surfaced method of running generate without accessing model object.
Generate outputs from any GPT4All model.
Args:
prompt: Raw string to be passed to model.
streaming: True if want output streamed to stdout.
**generate_kwargs: Optional kwargs to pass to prompt context.
Returns:
Raw string of generated model response.
"""
return self.model.prompt_model(prompt, streaming=streaming, **generate_kwargs)
def generator(self, prompt: str, **generate_kwargs) -> str:
"""
Surfaced method of running generate without accessing model object.
Args:
prompt: Raw string to be passed to model.
streaming: True if want output streamed to stdout.
**generate_kwargs: Optional kwargs to pass to prompt context.
Returns:
Raw string of generated model response.
"""
return self.model.generator(prompt, **generate_kwargs)
def chat_completion(self,
messages: List[Dict],
default_prompt_header: bool = True,
default_prompt_footer: bool = True,
verbose: bool = True,
streaming: bool = True,
**generate_kwargs) -> dict:
"""
Format list of message dictionaries into a prompt and call model
generate on prompt. Returns a response dictionary with metadata and
generated content.
Args:
messages: List of dictionaries. Each dictionary should have a "role" key
with value of "system", "assistant", or "user" and a "content" key with a
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Response: {content}".
default_prompt_header: If True (default), add default prompt header after any system role messages and
before user/assistant role messages.
default_prompt_footer: If True (default), add default footer at end of prompt.
verbose: If True (default), print full prompt and generated response.
streaming: True if want output streamed to stdout.
**generate_kwargs: Optional kwargs to pass to prompt context.
prompt: The prompt for the model to complete.
max_tokens: The maximum number of tokens to generate.
temp: The model temperature. Larger values increase creativity but decrease factuality.
top_k: Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.
top_p: Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.
repeat_penalty: Penalize the model for repetition. Higher values result in less repetition.
repeat_last_n: How far back in the model's generation history to apply the repeat penalty.
n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
n_predict: Equivalent to max_tokens, exists for backwards compatibility.
streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
Returns:
Response dictionary with:
"model": name of model.
"usage": a dictionary with number of full prompt tokens, number of
generated tokens in response, and total tokens.
"choices": List of message dictionary where "content" is generated response and "role" is set
as "assistant". Right now, only one choice is returned by model.
Either the entire completion or a generator that yields the completion token by token.
"""
full_prompt = self._build_prompt(messages,
default_prompt_header=default_prompt_header,
default_prompt_footer=default_prompt_footer)
if verbose:
print(full_prompt)
generate_kwargs = locals()
generate_kwargs.pop('self')
generate_kwargs.pop('max_tokens')
generate_kwargs.pop('streaming')
generate_kwargs['n_predict'] = max_tokens
if n_predict is not None:
generate_kwargs['n_predict'] = n_predict
response = self.model.prompt_model(full_prompt, streaming=streaming, **generate_kwargs)
if streaming and self._is_chat_session_activated:
raise NotImplementedError("Streaming tokens in a chat session is not currently supported.")
if verbose and not streaming:
print(response)
if self._is_chat_session_activated:
self.current_chat_session.append({"role": "user", "content": prompt})
generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session)
generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
else:
generate_kwargs['reset_context'] = True
response_dict = {
"model": self.model.model_name,
"usage": {"prompt_tokens": len(full_prompt),
"completion_tokens": len(response),
"total_tokens": len(full_prompt) + len(response)},
"choices": [
{
"message": {
"role": "assistant",
"content": response
}
}
]
}
if streaming:
return self.model.prompt_model_streaming(**generate_kwargs)
return response_dict
output = self.model.prompt_model(**generate_kwargs)
@staticmethod
def _build_prompt(messages: List[Dict],
default_prompt_header=True,
default_prompt_footer=True) -> str:
if self._is_chat_session_activated:
self.current_chat_session.append({"role": "assistant", "content": output})
return output
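
Putting the unified method together, a usage sketch (assuming the orca-mini model file used by the tests is available or downloadable):

from gpt4all import GPT4All

model = GPT4All(model_name="orca-mini-3b.ggmlv3.q4_0.bin")

# Blocking call: returns the whole completion as a string.
text = model.generate("Why is the sky blue?", max_tokens=100, temp=0.7)

# Streaming call: returns a generator that yields tokens as the model produces them.
for token in model.generate("Why is the sky blue?", streaming=True):
    print(token, end="")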
@contextmanager
def chat_session(self):
'''
Context manager to hold an inference-optimized chat session with a GPT4All model.
'''
# Enter chat-session state: subsequent generate() calls share one model context.
self._is_chat_session_activated = True
self.current_chat_session = []
try:
yield self
finally:
# Leave chat-session state and clear the accumulated conversation.
self._is_chat_session_activated = False
self.current_chat_session = []
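
Usage mirrors the new unit tests: prompts issued inside the with block share a single model context, and the accumulated conversation is exposed on current_chat_session:

from gpt4all import GPT4All

model = GPT4All(model_name="orca-mini-3b.ggmlv3.q4_0.bin")

with model.chat_session():
    model.generate(prompt="hello", top_k=1)
    model.generate(prompt="write me a short poem", top_k=1)
    # List of {"role": ..., "content": ...} dicts for the session so far.
    print(model.current_chat_session)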
def _format_chat_prompt_template(
self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
) -> str:
"""
Helper method for building a prompt from a list of messages using the chat prompt template.
@@ -269,37 +257,20 @@ class GPT4All():
string value. Messages are organized such that "system" messages are at top of prompt,
and "user" and "assistant" messages are displayed in order. Assistant messages get formatted as
"Response: {content}".
default_prompt_header: If True (default), add default prompt header after any system role messages and
before user/assistant role messages.
default_prompt_footer: If True (default), add default footer at end of prompt.
Returns:
Formatted prompt.
"""
full_prompt = ""
for message in messages:
if message["role"] == "system":
system_message = message["content"] + "\n"
full_prompt += system_message
if default_prompt_header:
full_prompt += """### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation
to respond to; decide which and write an appropriate response.
\n### Prompt: """
for message in messages:
if message["role"] == "user":
user_message = "\n" + message["content"]
user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
full_prompt += user_message
if message["role"] == "assistant":
assistant_message = "\n### Response: " + message["content"]
assistant_message = message["content"] + '\n'
full_prompt += assistant_message
if default_prompt_footer:
full_prompt += "\n### Response:"
return full_prompt
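
To make the new template concrete, this sketch shows roughly what _format_chat_prompt_template builds for one system message and one user turn (the header/footer flags stay in the signature but appear to add no text in this version):

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]

# The resulting prompt string looks approximately like:
#
#   You are a helpful assistant.
#   ### Human:
#   hello
#   ### Assistant: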

View File

@@ -1,4 +1,3 @@
import pkg_resources
import ctypes
import os
import platform
@@ -7,6 +6,10 @@ import re
import subprocess
import sys
import threading
from typing import Iterable
import pkg_resources
class DualStreamProcessor:
def __init__(self, stream=None):
@@ -19,10 +22,12 @@ class DualStreamProcessor:
self.stream.flush()
self.output += text
# TODO: provide a config file to make this more robust
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
def load_llmodel_library():
system = platform.system()
@@ -40,34 +45,40 @@ def load_llmodel_library():
llmodel_file = "libllmodel" + '.' + c_lib_ext
llmodel_dir = str(pkg_resources.resource_filename('gpt4all', \
os.path.join(LLMODEL_PATH, llmodel_file))).replace("\\", "\\\\")
llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
"\\", "\\\\"
)
llmodel_lib = ctypes.CDLL(llmodel_dir)
return llmodel_lib
llmodel = load_llmodel_library()
class LLModelError(ctypes.Structure):
_fields_ = [("message", ctypes.c_char_p),
("code", ctypes.c_int32)]
_fields_ = [("message", ctypes.c_char_p), ("code", ctypes.c_int32)]
class LLModelPromptContext(ctypes.Structure):
_fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
("logits_size", ctypes.c_size_t),
("tokens", ctypes.POINTER(ctypes.c_int32)),
("tokens_size", ctypes.c_size_t),
("n_past", ctypes.c_int32),
("n_ctx", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
("temp", ctypes.c_float),
("n_batch", ctypes.c_int32),
("repeat_penalty", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float)]
_fields_ = [
("logits", ctypes.POINTER(ctypes.c_float)),
("logits_size", ctypes.c_size_t),
("tokens", ctypes.POINTER(ctypes.c_int32)),
("tokens_size", ctypes.c_size_t),
("n_past", ctypes.c_int32),
("n_ctx", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
("temp", ctypes.c_float),
("n_batch", ctypes.c_int32),
("repeat_penalty", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float),
]
# Define C function signatures using ctypes
llmodel.llmodel_model_create.argtypes = [ctypes.c_char_p]
@@ -90,12 +101,14 @@ PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
ctypes.c_char_p,
PromptCallback,
ResponseCallback,
RecalculateCallback,
ctypes.POINTER(LLModelPromptContext)]
llmodel.llmodel_prompt.argtypes = [
ctypes.c_void_p,
ctypes.c_char_p,
PromptCallback,
ResponseCallback,
RecalculateCallback,
ctypes.POINTER(LLModelPromptContext),
]
llmodel.llmodel_prompt.restype = None
@@ -142,7 +155,6 @@ class LLModel:
else:
raise ValueError("Unable to instantiate model")
def load_model(self, model_path: str) -> bool:
"""
Load model from a file.
@@ -182,21 +194,59 @@ class LLModel:
raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model)
def prompt_model(self,
prompt: str,
logits_size: int = 0,
tokens_size: int = 0,
n_past: int = 0,
n_ctx: int = 1024,
n_predict: int = 128,
top_k: int = 40,
top_p: float = .9,
temp: float = .1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = .5,
streaming: bool = True) -> str:
def _set_context(
self,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
):
if self.context is None:
self.context = LLModelPromptContext(
logits_size=0,
tokens_size=0,
n_past=0,
n_ctx=0,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
)
elif reset_context:
self.context.n_past = 0
self.context.n_predict = n_predict
self.context.top_k = top_k
self.context.top_p = top_p
self.context.temp = temp
self.context.n_batch = n_batch
self.context.repeat_penalty = repeat_penalty
self.context.repeat_last_n = repeat_last_n
self.context.context_erase = context_erase
def prompt_model(
self,
prompt: str,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
streaming=False,
) -> str:
"""
Generate a response from the model for a given prompt.
@@ -211,117 +261,100 @@ class LLModel:
-------
Model response str
"""
prompt = prompt.encode('utf-8')
prompt = ctypes.c_char_p(prompt)
old_stdout = sys.stdout
old_stdout = sys.stdout
stream_processor = DualStreamProcessor()
if streaming:
stream_processor.stream = sys.stdout
sys.stdout = stream_processor
self._set_context(
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
reset_context=reset_context,
)
if self.context is None:
self.context = LLModelPromptContext(
logits_size=logits_size,
tokens_size=tokens_size,
n_past=n_past,
n_ctx=n_ctx,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase
)
llmodel.llmodel_prompt(self.model,
prompt,
PromptCallback(self._prompt_callback),
ResponseCallback(self._response_callback),
RecalculateCallback(self._recalculate_callback),
self.context)
llmodel.llmodel_prompt(
self.model,
prompt,
PromptCallback(self._prompt_callback),
ResponseCallback(self._response_callback),
RecalculateCallback(self._recalculate_callback),
self.context,
)
# Revert to old stdout
sys.stdout = old_stdout
# Force new line
print()
return stream_processor.output
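
The non-streaming path leans on the DualStreamProcessor defined earlier: stdout is temporarily swapped for an object that accumulates everything written by the response callback and, when streaming to the console, mirrors it as well. A reduced, backend-free sketch of that capture pattern (the class name and the None-check in write() are assumptions for this standalone version):

import sys

class CapturingStream:
    # Collects written text and optionally mirrors it to another stream.
    def __init__(self, stream=None):
        self.stream = stream
        self.output = ""

    def write(self, text):
        if self.stream is not None:
            self.stream.write(text)
            self.stream.flush()
        self.output += text

old_stdout = sys.stdout
sys.stdout = CapturingStream(stream=None)   # capture silently
print("tokens written via the response callback would land here")
captured = sys.stdout.output
sys.stdout = old_stdout
print("captured:", captured.strip())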
def generator(self,
prompt: str,
logits_size: int = 0,
tokens_size: int = 0,
n_past: int = 0,
n_ctx: int = 1024,
n_predict: int = 128,
top_k: int = 40,
top_p: float = .9,
temp: float = .1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = .5) -> str:
def prompt_model_streaming(
self,
prompt: str,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
) -> Iterable:
# Symbol to terminate from generator
TERMINATING_SYMBOL = "#TERMINATE#"
output_queue = queue.Queue()
prompt = prompt.encode('utf-8')
prompt = ctypes.c_char_p(prompt)
if self.context is None:
self.context = LLModelPromptContext(
logits_size=logits_size,
tokens_size=tokens_size,
n_past=n_past,
n_ctx=n_ctx,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase
)
self._set_context(
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
reset_context=reset_context,
)
# Put response tokens into an output queue
def _generator_response_callback(token_id, response):
output_queue.put(response.decode('utf-8', 'replace'))
return True
def run_llmodel_prompt(model,
prompt,
prompt_callback,
response_callback,
recalculate_callback,
context):
llmodel.llmodel_prompt(model,
prompt,
prompt_callback,
response_callback,
recalculate_callback,
context)
def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
output_queue.put(TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator
# immediately
thread = threading.Thread(target=run_llmodel_prompt,
args=(self.model,
prompt,
PromptCallback(self._prompt_callback),
ResponseCallback(_generator_response_callback),
RecalculateCallback(self._recalculate_callback),
self.context))
thread = threading.Thread(
target=run_llmodel_prompt,
args=(
self.model,
prompt,
PromptCallback(self._prompt_callback),
ResponseCallback(_generator_response_callback),
RecalculateCallback(self._recalculate_callback),
self.context,
),
)
thread.start()
# Generator
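
The streaming method reduces to a general pattern: run the blocking, callback-driven C call on a worker thread, push each token into a queue, and yield from that queue until a sentinel arrives. A self-contained sketch of that pattern (fake_llmodel_prompt and stream_tokens are illustrative stand-ins, not part of the bindings; the real call needs the compiled backend):

import queue
import threading
from typing import Callable, Iterator

TERMINATING_SYMBOL = "#TERMINATE#"

def fake_llmodel_prompt(response_callback: Callable[[str], bool]) -> None:
    # Stand-in for llmodel.llmodel_prompt: calls the callback once per token.
    for token in ["Hello", ",", " world", "!"]:
        if not response_callback(token):
            break

def stream_tokens() -> Iterator[str]:
    output_queue: "queue.Queue[str]" = queue.Queue()

    def run_prompt() -> None:
        fake_llmodel_prompt(lambda tok: (output_queue.put(tok), True)[1])
        output_queue.put(TERMINATING_SYMBOL)  # signal the generator to stop

    # Kick off the blocking call on its own thread so the generator returns immediately.
    threading.Thread(target=run_prompt, daemon=True).start()
    while True:
        token = output_queue.get()
        if token == TERMINATING_SYMBOL:
            break
        yield token

for t in stream_tokens():
    print(t, end="")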

View File

@@ -0,0 +1,54 @@
import sys
from io import StringIO
from gpt4all import GPT4All
def test_inference():
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
output_1 = model.generate('hello', top_k=1)
with model.chat_session():
response = model.generate(prompt='hello', top_k=1)
response = model.generate(prompt='write me a short poem', top_k=1)
response = model.generate(prompt='thank you', top_k=1)
print(model.current_chat_session)
output_2 = model.generate('hello', top_k=1)
assert output_1 == output_2
tokens = []
for token in model.generate('hello', streaming=True):
tokens.append(token)
assert len(tokens) > 0
with model.chat_session():
try:
response = model.generate(prompt='hello', top_k=1, streaming=True)
assert False
except NotImplementedError:
assert True
def test_inference_hparams():
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
output = model.generate("The capital of france is ", max_tokens=3)
assert 'Paris' in output
def test_inference_falcon():
model = GPT4All(model_name='ggml-model-gpt4all-falcon-q4_0.bin')
prompt = 'hello'
output = model.generate(prompt)
assert len(output) > 0
def test_inference_mpt():
model = GPT4All(model_name='ggml-mpt-7b-chat.bin')
prompt = 'hello'
output = model.generate(prompt)
assert len(output) > 0
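
One check that could sit alongside these (not part of this commit): the commit message notes that n_predict is preserved as an alias for max_tokens for LangChain compatibility, so a hypothetical regression test reusing the imports above might look like:

def test_inference_n_predict_alias():
    # Hypothetical: n_predict should be honored like max_tokens when supplied.
    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
    output = model.generate("The capital of france is ", n_predict=3)
    assert len(output) > 0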