Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-23 05:58:48 +00:00)
adds a simple cli chat repl (#566)

* adds a simple cli chat repl
* add n thread support and append assistant response
parent 95a4516844
commit d4861030b7
gpt4all-bindings/cli/app.py: 118 lines (new file)
@@ -0,0 +1,118 @@
import sys
import typer

from typing_extensions import Annotated
from gpt4all import GPT4All

MESSAGES = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello there."},
    {"role": "assistant", "content": "Hi, how can I help you?"},
]

SPECIAL_COMMANDS = {
    "/reset": lambda messages: messages.clear(),
    "/exit": lambda _: sys.exit(),
    "/clear": lambda _: print("\n" * 100),
    "/help": lambda _: print("Special commands: /reset, /exit, /help and /clear"),
}

VERSION = "0.1.0"

CLI_START_MESSAGE = f"""

 ██████  ██████  ████████ ██   ██  █████  ██      ██
██       ██   ██    ██    ██   ██ ██   ██ ██      ██
██   ███ ██████     ██    ███████ ███████ ██      ██
██    ██ ██         ██         ██ ██   ██ ██      ██
 ██████  ██         ██         ██ ██   ██ ███████ ███████


Welcome to the GPT4All CLI! Version {VERSION}
Type /help for special commands.

"""

def _cli_override_response_callback(token_id, response):
    resp = response.decode("utf-8")
    print(resp, end="", flush=True)
    return True


# create typer app
app = typer.Typer()

@app.command()
def repl(
    model: Annotated[
        str,
        typer.Option("--model", "-m", help="Model to use for chatbot"),
    ] = "ggml-gpt4all-j-v1.3-groovy",
    n_threads: Annotated[
        int,
        typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"),
    ] = 4,
):
    gpt4all_instance = GPT4All(model)

    # if threads are passed, set them
    if n_threads != 4:
        num_threads = gpt4all_instance.model.thread_count()
        print(f"\nAdjusted: {num_threads} →", end="")

        # set number of threads
        gpt4all_instance.model.set_thread_count(n_threads)

        num_threads = gpt4all_instance.model.thread_count()
        print(f" {num_threads} threads", end="", flush=True)


    # overwrite _response_callback on model
    gpt4all_instance.model._response_callback = _cli_override_response_callback

    print(CLI_START_MESSAGE)

    while True:
        message = input(" ⇢ ")

        # Check if special command and take action
        if message in SPECIAL_COMMANDS:
            SPECIAL_COMMANDS[message](MESSAGES)
            continue

        # if regular message, append to messages
        MESSAGES.append({"role": "user", "content": message})

        # execute chat completion and ignore the full response since
        # we are outputting it incrementally
        full_response = gpt4all_instance.chat_completion(
            MESSAGES,
            # preferential kwargs for chat ux
            logits_size=0,
            tokens_size=0,
            n_past=0,
            n_ctx=0,
            n_predict=200,
            top_k=40,
            top_p=0.9,
            temp=0.9,
            n_batch=9,
            repeat_penalty=1.1,
            repeat_last_n=64,
            context_erase=0.0,
            # required kwargs for cli ux (incremental response)
            verbose=False,
            std_passthrough=True,
        )
        # record assistant's response to messages
        MESSAGES.append(full_response.get("choices")[0].get("message"))
        print()  # newline before next prompt


@app.command()
def version():
    print("gpt4all-cli v0.1.0")


if __name__ == "__main__":
    app()
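A quick, hypothetical smoke test of the new Typer app (not part of the commit; it assumes app.py is importable from the working directory):

# Hedged sketch: exercise the `version` command in-process via Typer's test runner.
from typer.testing import CliRunner
from app import app  # the typer.Typer() instance defined above

runner = CliRunner()
result = runner.invoke(app, ["version"])
assert result.output.strip() == "gpt4all-cli v0.1.0"

# From a shell, the REPL itself would be started along the lines of:
#   python app.py repl --model ggml-gpt4all-j-v1.3-groovy --n-threads 8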
@@ -6,6 +6,21 @@ import platform
 import re
 import sys
 
+class DualOutput:
+    def __init__(self, stdout, string_io):
+        self.stdout = stdout
+        self.string_io = string_io
+
+    def write(self, text):
+        self.stdout.write(text)
+        self.string_io.write(text)
+
+    def flush(self):
+        # It's a good idea to also define a flush method that flushes both
+        # outputs, as sys.stdout is expected to have this method.
+        self.stdout.flush()
+        self.string_io.flush()
+
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
 
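A minimal sketch (not part of the commit) of what the new DualOutput class does: every write is mirrored to the real terminal and captured in a StringIO buffer, so streamed tokens stay visible while the full text can still be read back afterwards. It assumes DualOutput from the hunk above is in scope.

import sys
from io import StringIO

buffer = StringIO()
real_stdout = sys.stdout
sys.stdout = DualOutput(real_stdout, buffer)  # tee stdout to terminal + buffer
print("streamed token", end="", flush=True)   # appears on the terminal immediately
sys.stdout = real_stdout                      # restore the real stdout
print("\ncaptured:", buffer.getvalue())       # the same text was also collected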
@@ -81,6 +96,15 @@ llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
                                    RecalculateCallback,
                                    ctypes.POINTER(LLModelPromptContext)]
 
+llmodel.llmodel_prompt.restype = None
+
+llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
+llmodel.llmodel_setThreadCount.restype = None
+
+llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_threadCount.restype = ctypes.c_int32
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -125,6 +149,18 @@ class LLModel:
         else:
             return False
 
+
+    def set_thread_count(self, n_threads):
+        if not llmodel.llmodel_isModelLoaded(self.model):
+            raise Exception("Model not loaded")
+        llmodel.llmodel_setThreadCount(self.model, n_threads)
+
+    def thread_count(self):
+        if not llmodel.llmodel_isModelLoaded(self.model):
+            raise Exception("Model not loaded")
+        return llmodel.llmodel_threadCount(self.model)
+
+
     def generate(self,
                  prompt: str,
                  logits_size: int = 0,
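A small usage sketch (not part of the commit) of the two new methods, mirroring how the CLI above calls them through the GPT4All wrapper; the model name is just an example:

from gpt4all import GPT4All

gpt4all_instance = GPT4All("ggml-gpt4all-j-v1.3-groovy")
print(gpt4all_instance.model.thread_count())   # read the current thread count
gpt4all_instance.model.set_thread_count(8)     # both calls require a loaded model
print(gpt4all_instance.model.thread_count())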
@@ -138,7 +174,8 @@ class LLModel:
                  n_batch: int = 8,
                  repeat_penalty: float = 1.2,
                  repeat_last_n: int = 10,
-                 context_erase: float = .5) -> str:
+                 context_erase: float = .5,
+                 std_passthrough: bool = False) -> str:
         """
         Generate response from model from a prompt.
 
@@ -164,7 +201,10 @@ class LLModel:
         # Change stdout to StringIO so we can collect response
         old_stdout = sys.stdout
         collect_response = StringIO()
-        sys.stdout = collect_response
+        if std_passthrough:
+            sys.stdout = DualOutput(old_stdout, collect_response)
+        else:
+            sys.stdout = collect_response
 
         context = LLModelPromptContext(
             logits_size=logits_size,
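A sketch (not part of the commit; model name and prompt are placeholders) of what the new flag changes: with std_passthrough=True the tokens are echoed to the terminal as they are generated and the full text is still returned, which is what the CLI relies on for incremental output.

from gpt4all import GPT4All

model = GPT4All("ggml-gpt4all-j-v1.3-groovy")
# echoes tokens to the terminal while generating, then returns the collected string
response = model.model.generate("Name three colors.", std_passthrough=True)
# default behaviour: nothing is echoed, the response is only returned
quiet_response = model.model.generate("Name three colors.")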
@@ -222,7 +262,7 @@ class GPTJModel(LLModel):
         self.model = llmodel.llmodel_gptj_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_gptj_destroy(self.model)
         super().__del__()
 
@@ -236,7 +276,7 @@ class LlamaModel(LLModel):
         self.model = llmodel.llmodel_llama_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
 
@@ -250,6 +290,6 @@ class MPTModel(LLModel):
         self.model = llmodel.llmodel_mpt_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_mpt_destroy(self.model)
         super().__del__()