From d4861030b778da6db59d21d2927a4aba4f9f1f43 Mon Sep 17 00:00:00 2001
From: drbh
Date: Tue, 16 May 2023 13:47:54 -0700
Subject: [PATCH] adds a simple cli chat repl (#566)

* adds a simple cli chat repl

* add n thread support and append assistant response
---
 gpt4all-bindings/cli/app.py                  | 118 +++++++++++++++++++
 gpt4all-bindings/python/gpt4all/pyllmodel.py |  50 +++++++-
 2 files changed, 163 insertions(+), 5 deletions(-)
 create mode 100644 gpt4all-bindings/cli/app.py

diff --git a/gpt4all-bindings/cli/app.py b/gpt4all-bindings/cli/app.py
new file mode 100644
index 00000000..953ed417
--- /dev/null
+++ b/gpt4all-bindings/cli/app.py
@@ -0,0 +1,118 @@
+import sys
+import typer
+
+from typing_extensions import Annotated
+from gpt4all import GPT4All
+
+MESSAGES = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Hello there."},
+    {"role": "assistant", "content": "Hi, how can I help you?"},
+]
+
+SPECIAL_COMMANDS = {
+    "/reset": lambda messages: messages.clear(),
+    "/exit": lambda _: sys.exit(),
+    "/clear": lambda _: print("\n" * 100),
+    "/help": lambda _: print("Special commands: /reset, /exit, /help and /clear"),
+}
+
+VERSION = "0.1.0"
+
+CLI_START_MESSAGE = f"""
+
+ ██████  ██████  ████████ ██   ██  █████  ██      ██
+██       ██   ██    ██    ██   ██ ██   ██ ██      ██
+██   ███ ██████     ██    ███████ ███████ ██      ██
+██    ██ ██         ██         ██ ██   ██ ██      ██
+ ██████  ██         ██         ██ ██   ██ ███████ ███████
+
+
+Welcome to the GPT4All CLI! Version {VERSION}
+Type /help for special commands.
+
+"""
+
+def _cli_override_response_callback(token_id, response):
+    # decode and print each streamed token as soon as it arrives
+    resp = response.decode("utf-8")
+    print(resp, end="", flush=True)
+    return True
+
+
+# create typer app
+app = typer.Typer()
+
+@app.command()
+def repl(
+    model: Annotated[
+        str,
+        typer.Option("--model", "-m", help="Model to use for chatbot"),
+    ] = "ggml-gpt4all-j-v1.3-groovy",
+    n_threads: Annotated[
+        int,
+        typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"),
+    ] = 4,
+):
+    gpt4all_instance = GPT4All(model)
+
+    # if a non-default thread count was requested, apply it
+    if n_threads != 4:
+        num_threads = gpt4all_instance.model.thread_count()
+        print(f"\nAdjusted: {num_threads} →", end="")
+
+        # set number of threads
+        gpt4all_instance.model.set_thread_count(n_threads)
+
+        num_threads = gpt4all_instance.model.thread_count()
+        print(f" {num_threads} threads", end="", flush=True)
+
+    # overwrite _response_callback on model
+    gpt4all_instance.model._response_callback = _cli_override_response_callback
+
+    print(CLI_START_MESSAGE)
+
+    while True:
+        message = input(" ⇢ ")
+
+        # Check if special command and take action
+        if message in SPECIAL_COMMANDS:
+            SPECIAL_COMMANDS[message](MESSAGES)
+            continue
+
+        # if regular message, append to messages
+        MESSAGES.append({"role": "user", "content": message})
+
+        # run chat completion; tokens are streamed to stdout by the
+        # response callback, and the full response is kept for the history
+        full_response = gpt4all_instance.chat_completion(
+            MESSAGES,
+            # preferential kwargs for chat ux
+            logits_size=0,
+            tokens_size=0,
+            n_past=0,
+            n_ctx=0,
+            n_predict=200,
+            top_k=40,
+            top_p=0.9,
+            temp=0.9,
+            n_batch=9,
+            repeat_penalty=1.1,
+            repeat_last_n=64,
+            context_erase=0.0,
+            # required kwargs for cli ux (incremental response)
+            verbose=False,
+            std_passthrough=True,
+        )
+        # record the assistant's response in the message history
+        MESSAGES.append(full_response.get("choices")[0].get("message"))
+        print()  # newline before next prompt
+
+
+@app.command()
+def version():
+    print(f"gpt4all-cli v{VERSION}")
+
+
+if __name__ == "__main__":
+    app()
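Note: the new CLI is a Typer app with two subcommands, repl and version. Assuming the gpt4all and typer packages are installed and the command is run from the gpt4all-bindings/cli directory, a chat session can be started like this (the flags mirror the typer.Option declarations above):

    python app.py repl --model ggml-gpt4all-j-v1.3-groovy --n-threads 8
    python app.py version

Inside the REPL, /help lists the special commands; /reset clears the message history and /exit leaves the loop.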
diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index d194c234..a1f29f4d 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -6,6 +6,21 @@
 import platform
 import re
 import sys
+class DualOutput:
+    def __init__(self, stdout, string_io):
+        self.stdout = stdout
+        self.string_io = string_io
+
+    def write(self, text):
+        self.stdout.write(text)
+        self.string_io.write(text)
+
+    def flush(self):
+        # sys.stdout is expected to provide flush(), so mirror it
+        # here and flush both underlying outputs.
+        self.stdout.flush()
+        self.string_io.flush()
+
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
@@ -81,6 +96,15 @@ llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
                                    RecalculateCallback,
                                    ctypes.POINTER(LLModelPromptContext)]
 
+llmodel.llmodel_prompt.restype = None
+
+llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
+llmodel.llmodel_setThreadCount.restype = None
+
+llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_threadCount.restype = ctypes.c_int32
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -125,6 +149,18 @@
         else:
             return False
 
+
+    def set_thread_count(self, n_threads):
+        if not llmodel.llmodel_isModelLoaded(self.model):
+            raise Exception("Model not loaded")
+        llmodel.llmodel_setThreadCount(self.model, n_threads)
+
+    def thread_count(self):
+        if not llmodel.llmodel_isModelLoaded(self.model):
+            raise Exception("Model not loaded")
+        return llmodel.llmodel_threadCount(self.model)
+
+
     def generate(self,
                  prompt: str,
                  logits_size: int = 0,
@@ -138,7 +174,8 @@
                  n_batch: int = 8,
                  repeat_penalty: float = 1.2,
                  repeat_last_n: int = 10,
-                 context_erase: float = .5) -> str:
+                 context_erase: float = .5,
+                 std_passthrough: bool = False) -> str:
         """
         Generate response from model from a prompt.
 
@@ -164,7 +201,10 @@
         # Change stdout to StringIO so we can collect response
         old_stdout = sys.stdout
         collect_response = StringIO()
-        sys.stdout = collect_response
+        if std_passthrough:
+            sys.stdout = DualOutput(old_stdout, collect_response)
+        else:
+            sys.stdout = collect_response
 
         context = LLModelPromptContext(
             logits_size=logits_size,
@@ -222,7 +262,7 @@
         self.model = llmodel.llmodel_gptj_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_gptj_destroy(self.model)
         super().__del__()
 
@@ -236,7 +276,7 @@
         self.model = llmodel.llmodel_llama_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
 
@@ -250,6 +290,6 @@
         self.model = llmodel.llmodel_mpt_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_mpt_destroy(self.model)
         super().__del__()
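Note: DualOutput is a small tee for sys.stdout. generate() keeps capturing the model's output in a StringIO buffer, but with std_passthrough=True every token is also echoed to the terminal as it is produced, which is what makes the CLI's incremental display work. A minimal, self-contained sketch of the same pattern (the try/finally restore is an added safety measure, not something the patch itself does):

    import sys
    from io import StringIO

    class DualOutput:
        # mirror every write to the real stdout and an in-memory buffer
        def __init__(self, stdout, string_io):
            self.stdout = stdout
            self.string_io = string_io

        def write(self, text):
            self.stdout.write(text)
            self.string_io.write(text)

        def flush(self):
            self.stdout.flush()
            self.string_io.flush()

    old_stdout = sys.stdout
    buffer = StringIO()
    sys.stdout = DualOutput(old_stdout, buffer)
    try:
        print("hello")           # reaches the terminal AND the buffer
    finally:
        sys.stdout = old_stdout  # always restore the real stdout

    assert buffer.getvalue() == "hello\n"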
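Note: the new llmodel_setThreadCount/llmodel_threadCount declarations follow the usual ctypes discipline: without argtypes/restype, ctypes converts nothing and assumes every foreign function returns a C int. A short illustration of the same idea against libc's strlen (POSIX only, purely illustrative; not part of the patch):

    import ctypes

    libc = ctypes.CDLL(None)  # load the C runtime (POSIX)

    # declare the prototype so ctypes checks arguments and converts
    # the return value, instead of guessing "int" for everything
    libc.strlen.argtypes = [ctypes.c_char_p]
    libc.strlen.restype = ctypes.c_size_t

    print(libc.strlen(b"hello"))  # 5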
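Note: from the Python bindings, the thread-count API added here is reached through the GPT4All instance's model attribute, exactly as the CLI's repl command does. A hedged usage sketch (model name as in the CLI default; the model must be loaded first, otherwise both methods raise):

    from gpt4all import GPT4All

    gpt4all_instance = GPT4All("ggml-gpt4all-j-v1.3-groovy")
    print(gpt4all_instance.model.thread_count())   # default, e.g. 4
    gpt4all_instance.model.set_thread_count(8)
    print(gpt4all_instance.model.thread_count())   # now 8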