fix prompt context so it's preserved in the class

Richard Guo 2023-06-12 22:38:50 -04:00
parent 85964a7635
commit a99cc34efb


@@ -125,6 +125,7 @@ class LLModel:
     def __init__(self):
         self.model = None
         self.model_name = None
+        self.context = None
 
     def __del__(self):
         if self.model is not None:
@@ -211,27 +212,29 @@ class LLModel:
         sys.stdout = stream_processor
 
-        context = LLModelPromptContext(
-            logits_size=logits_size,
-            tokens_size=tokens_size,
-            n_past=n_past,
-            n_ctx=n_ctx,
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase
-        )
+        if self.context is None:
+            self.context = LLModelPromptContext(
+                logits_size=logits_size,
+                tokens_size=tokens_size,
+                n_past=n_past,
+                n_ctx=n_ctx,
+                n_predict=n_predict,
+                top_k=top_k,
+                top_p=top_p,
+                temp=temp,
+                n_batch=n_batch,
+                repeat_penalty=repeat_penalty,
+                repeat_last_n=repeat_last_n,
+                context_erase=context_erase
+            )
 
         llmodel.llmodel_prompt(self.model,
                                prompt,
                                PromptCallback(self._prompt_callback),
                                ResponseCallback(self._response_callback),
                                RecalculateCallback(self._recalculate_callback),
-                               context)
+                               self.context)
 
         # Revert to old stdout
         sys.stdout = old_stdout
@@ -262,20 +265,21 @@ class LLModel:
         prompt = prompt.encode('utf-8')
         prompt = ctypes.c_char_p(prompt)
 
-        context = LLModelPromptContext(
-            logits_size=logits_size,
-            tokens_size=tokens_size,
-            n_past=n_past,
-            n_ctx=n_ctx,
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase
-        )
+        if self.context is None:
+            self.context = LLModelPromptContext(
+                logits_size=logits_size,
+                tokens_size=tokens_size,
+                n_past=n_past,
+                n_ctx=n_ctx,
+                n_predict=n_predict,
+                top_k=top_k,
+                top_p=top_p,
+                temp=temp,
+                n_batch=n_batch,
+                repeat_penalty=repeat_penalty,
+                repeat_last_n=repeat_last_n,
+                context_erase=context_erase
+            )
 
         # Put response tokens into an output queue
         def _generator_response_callback(token_id, response):
@@ -305,7 +309,7 @@ class LLModel:
                                PromptCallback(self._prompt_callback),
                                ResponseCallback(_generator_response_callback),
                                RecalculateCallback(self._recalculate_callback),
-                               context))
+                               self.context))
         thread.start()
 
         # Generator
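
Before this change, both prompt paths built a fresh LLModelPromptContext on every call, so mutable state such as n_past was thrown away between prompts; caching the context on the instance lets that state carry over. A minimal sketch of the lazy-caching pattern, using stand-in names (PromptContext, Model, and the word-count token accounting are illustrative, not the real bindings' API):

    from dataclasses import dataclass

    @dataclass
    class PromptContext:       # stand-in for LLModelPromptContext
        n_past: int = 0        # tokens already processed; grows per call
        n_predict: int = 128

    class Model:
        def __init__(self):
            self.context = None    # created lazily on first prompt, then reused

        def prompt(self, text: str) -> None:
            # Reuse the cached context so state like n_past persists
            # across calls instead of being reset on every prompt.
            if self.context is None:
                self.context = PromptContext()
            self.context.n_past += len(text.split())  # toy token accounting

    m = Model()
    m.prompt("hello there")
    m.prompt("and again")
    print(m.context.n_past)  # 4: state survived both calls

One consequence of caching at the instance level: a caller who wants a fresh conversation on the same object would presumably need to reset self.context to None, since the keyword arguments are only consulted the first time the context is built.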