adds a simple cli chat repl (#566)

* adds a simple cli chat repl
* add n thread support and append assistant response
@@ -6,6 +6,21 @@ import platform
 import re
 import sys
 
+class DualOutput:
+    def __init__(self, stdout, string_io):
+        self.stdout = stdout
+        self.string_io = string_io
+
+    def write(self, text):
+        self.stdout.write(text)
+        self.string_io.write(text)
+
+    def flush(self):
+        # It's a good idea to also define a flush method that flushes both
+        # outputs, as sys.stdout is expected to have this method.
+        self.stdout.flush()
+        self.string_io.flush()
+
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
 
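Note: the DualOutput class added above is a stdout "tee": everything written goes both to the real terminal and to an in-memory buffer, which is what lets the REPL stream tokens live while still capturing the full response. A minimal self-contained sketch of the pattern (the class body mirrors the one in this diff):

import sys
from io import StringIO

class DualOutput:
    def __init__(self, stdout, string_io):
        self.stdout = stdout
        self.string_io = string_io
    def write(self, text):
        self.stdout.write(text)       # echo to the terminal
        self.string_io.write(text)    # and keep a copy
    def flush(self):
        self.stdout.flush()
        self.string_io.flush()

buffer = StringIO()
sys.stdout = DualOutput(sys.stdout, buffer)
print("hello")                        # appears on screen immediately
sys.stdout = sys.__stdout__           # restore the real stdout
assert buffer.getvalue() == "hello\n" # and was captured as well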
@@ -81,6 +96,15 @@ llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
                                    RecalculateCallback,
                                    ctypes.POINTER(LLModelPromptContext)]
 
+llmodel.llmodel_prompt.restype = None
+
+llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
+llmodel.llmodel_setThreadCount.restype = None
+
+llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_threadCount.restype = ctypes.c_int32
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
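The argtypes/restype assignments tell ctypes the C signatures of llmodel_setThreadCount and llmodel_threadCount, so arguments are converted and checked rather than guessed. An illustrative sketch of the same mechanism against libc (a hypothetical example assuming a Unix-like system; not part of this diff):

import ctypes

libc = ctypes.CDLL(None)              # load the already-linked C library
libc.abs.argtypes = [ctypes.c_int]    # declare the C parameter types
libc.abs.restype = ctypes.c_int       # and the return type
print(libc.abs(-5))                   # prints 5, marshalled correctly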
@@ -125,6 +149,18 @@ class LLModel:
         else:
             return False
 
+
+    def set_thread_count(self, n_threads):
+        if not llmodel.llmodel_isModelLoaded(self.model):
+            raise Exception("Model not loaded")
+        llmodel.llmodel_setThreadCount(self.model, n_threads)
+
+    def thread_count(self):
+        if not llmodel.llmodel_isModelLoaded(self.model):
+            raise Exception("Model not loaded")
+        return llmodel.llmodel_threadCount(self.model)
+
+
     def generate(self,
                  prompt: str,
                  logits_size: int = 0,
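With these methods a caller can pin the model's worker-thread count before generating. A hedged usage sketch (GPTJModel, load_model, and the new methods come from this file; the model path is hypothetical):

import os

model = GPTJModel()
model.load_model("./models/ggml-gpt4all-j.bin")   # hypothetical local path
model.set_thread_count(os.cpu_count() or 4)       # e.g. one thread per core
print("generating with", model.thread_count(), "threads")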
@@ -138,7 +174,8 @@ class LLModel:
                  n_batch: int = 8,
                  repeat_penalty: float = 1.2,
                  repeat_last_n: int = 10,
-                 context_erase: float = .5) -> str:
+                 context_erase: float = .5,
+                 std_passthrough: bool = False) -> str:
         """
         Generate response from model from a prompt.
 
@@ -164,7 +201,10 @@ class LLModel:
         # Change stdout to StringIO so we can collect response
         old_stdout = sys.stdout
         collect_response = StringIO()
-        sys.stdout = collect_response
+        if std_passthrough:
+            sys.stdout = DualOutput(old_stdout, collect_response)
+        else:
+            sys.stdout = collect_response
 
         context = LLModelPromptContext(
             logits_size=logits_size,
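Passing std_passthrough=True routes the model's token output through DualOutput, so the CLI REPL can print the response as it streams while still receiving the complete string to append to the conversation. A sketch of how a chat loop might use it (history is a hypothetical list; generate is the method above):

history = []
while True:
    prompt = input("> ")
    response = model.generate(prompt, std_passthrough=True)
    # the streamed text was already echoed live; keep it for the next turn
    history.append({"role": "assistant", "content": response})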
@@ -222,7 +262,7 @@ class GPTJModel(LLModel):
         self.model = llmodel.llmodel_gptj_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_gptj_destroy(self.model)
         super().__del__()
 
@@ -236,7 +276,7 @@ class LlamaModel(LLModel):
         self.model = llmodel.llmodel_llama_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_llama_destroy(self.model)
         super().__del__()
 
@@ -250,6 +290,6 @@ class MPTModel(LLModel):
         self.model = llmodel.llmodel_mpt_create()
 
     def __del__(self):
-        if self.model is not None:
+        if self.model is not None and llmodel is not None:
             llmodel.llmodel_mpt_destroy(self.model)
         super().__del__()
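The extra "llmodel is not None" guard in each __del__ addresses interpreter shutdown: module globals can already be cleared (rebound to None) by the time destructors run, so calling through the llmodel binding would raise. A minimal sketch of the pattern, with hypothetical names:

class Handle:
    def __del__(self):
        # some_module_global may be None during interpreter teardown,
        # so check it before calling through it
        if self.ptr is not None and some_module_global is not None:
            some_module_global.destroy(self.ptr)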