fix prompt context so it's preserved in class

Richard Guo 2023-06-12 22:38:50 -04:00
parent 85964a7635
commit a99cc34efb


@@ -125,6 +125,7 @@ class LLModel:
     def __init__(self):
         self.model = None
         self.model_name = None
+        self.context = None
 
     def __del__(self):
         if self.model is not None:
@@ -211,7 +212,9 @@ class LLModel:
         sys.stdout = stream_processor
 
-        context = LLModelPromptContext(
+        if self.context is None:
+            self.context = LLModelPromptContext(
             logits_size=logits_size,
             tokens_size=tokens_size,
             n_past=n_past,
@@ -231,7 +234,7 @@
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._response_callback),
             RecalculateCallback(self._recalculate_callback),
-            context)
+            self.context)
 
         # Revert to old stdout
         sys.stdout = old_stdout
@ -262,7 +265,8 @@ class LLModel:
prompt = prompt.encode('utf-8')
prompt = ctypes.c_char_p(prompt)
context = LLModelPromptContext(
if self.context is None:
self.context = LLModelPromptContext(
logits_size=logits_size,
tokens_size=tokens_size,
n_past=n_past,
@@ -305,7 +309,7 @@
                 PromptCallback(self._prompt_callback),
                 ResponseCallback(_generator_response_callback),
                 RecalculateCallback(self._recalculate_callback),
-                context))
+                self.context))
         thread.start()
 
         # Generator
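
For reference, a minimal self-contained sketch of the pattern this commit introduces: the prompt context is created lazily on the first call and kept on the instance, so later prompt calls reuse the same state instead of rebuilding it each time. PromptContext and its fields below are illustrative stand-ins only, not the real LLModelPromptContext ctypes structure or the llmodel C API.

from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-in for LLModelPromptContext; the real structure is a
# ctypes.Structure with many more fields (n_predict, top_k, temp, ...).
@dataclass
class PromptContext:
    n_past: int = 0
    tokens: List[str] = field(default_factory=list)

class Model:
    def __init__(self):
        self.context = None  # preserved on the instance, as in this commit

    def prompt(self, text: str) -> str:
        # Create the context only on the first call; subsequent calls reuse
        # it, so conversation state (n_past, tokens) carries over.
        if self.context is None:
            self.context = PromptContext()
        self.context.tokens.extend(text.split())
        self.context.n_past = len(self.context.tokens)
        return f"model has seen {self.context.n_past} tokens"

m = Model()
print(m.prompt("hello there"))   # model has seen 2 tokens
print(m.prompt("how are you"))   # model has seen 5 tokens (state persisted)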