fix chat-style prompt templates (#1970)

Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel
Date: 2024-02-21 15:45:32 -05:00
Committed by: GitHub
Parent: b8f5c74f40
Commit: 4fc4d94be4

22 changed files with 429 additions and 307 deletions
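For context, the user-visible change in the Python bindings: inside a chat session, the prompt and the prompt template are now passed to the backend separately, and the template placeholder at the Python level is '{0}' rather than a literal '%1'. A minimal usage sketch (the model filename is illustrative, not pinned by this diff):

from gpt4all import GPT4All

model = GPT4All("mistral-7b-openorca.gguf2.Q4_0.gguf")  # illustrative filename

with model.chat_session():
    # generate() now hands the session's prompt template to the C backend
    # instead of baking it into the prompt string in Python.
    print(model.generate("What is the capital of France?", max_tokens=64))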


@@ -89,10 +89,12 @@ RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
     ctypes.c_char_p,
+    ctypes.c_char_p,
     PromptCallback,
     ResponseCallback,
     RecalculateCallback,
     ctypes.POINTER(LLModelPromptContext),
+    ctypes.c_bool,
 ]

 llmodel.llmodel_prompt.restype = None
@@ -290,6 +292,7 @@ class LLModel:
     def prompt_model(
         self,
         prompt: str,
+        prompt_template: str,
         callback: ResponseCallbackType,
         n_predict: int = 4096,
         top_k: int = 40,
@@ -300,6 +303,7 @@ class LLModel:
         repeat_last_n: int = 10,
         context_erase: float = 0.75,
         reset_context: bool = False,
+        special: bool = False,
     ):
         """
         Generate response from model from a prompt.
@@ -326,9 +330,6 @@ class LLModel:
             prompt,
         )

-        prompt_bytes = prompt.encode()
-        prompt_ptr = ctypes.c_char_p(prompt_bytes)
-
         self._set_context(
             n_predict=n_predict,
             top_k=top_k,
@@ -343,16 +344,18 @@ class LLModel:
         llmodel.llmodel_prompt(
             self.model,
-            prompt_ptr,
+            ctypes.c_char_p(prompt.encode()),
+            ctypes.c_char_p(prompt_template.encode()),
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._callback_decoder(callback)),
             RecalculateCallback(self._recalculate_callback),
             self.context,
+            special,
         )

     def prompt_model_streaming(
-        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
+        self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
         output_queue: Queue[str | Sentinel] = Queue()
@@ -369,15 +372,15 @@ class LLModel:
             return _generator_callback

-        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
-            self.prompt_model(prompt, callback, **kwargs)
+        def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, prompt_template, callback, **kwargs)
             output_queue.put(Sentinel.TERMINATING_SYMBOL)

         # Kick off llmodel_prompt in separate thread so we can return generator
         # immediately
         thread = threading.Thread(
             target=run_llmodel_prompt,
-            args=(prompt, _generator_callback_wrapper(callback)),
+            args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
             kwargs=kwargs,
         )
         thread.start()
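With this change, callers of the low-level wrapper supply the template alongside the prompt, and the bytes conversion happens inline at the call site. A rough sketch of the new call shape, assuming model is an LLModel instance with a model already loaded (the callback body and template string are illustrative):

def on_response(token_id: int, response: str) -> bool:
    print(response, end="", flush=True)
    return True  # returning False would stop generation

model.prompt_model(
    "What is 1+1?",                     # prompt: the user input only
    "### User:\n%1\n### Response:\n",   # prompt_template: '%1' is filled in by the C++ core
    on_response,
    n_predict=64,
)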


@@ -4,8 +4,10 @@ Python only API for running all GPT4All models.
 from __future__ import annotations

 import os
+import re
 import sys
 import time
+import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Union
@@ -314,6 +316,10 @@ class GPT4All:
             Either the entire completion or a generator that yields the completion token by token.
         """

+        if re.search(r"%1(?![0-9])", self._current_prompt_template):
+            raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
+                             "placeholder, please use '{0}' instead.")
+
         # Preparing the model request
         generate_kwargs: Dict[str, Any] = dict(
             temp=temp,
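The new guard rejects templates that already contain the backend-level '%1' placeholder; the negative lookahead presumably keeps longer sequences such as '%10' from tripping it. A quick standalone illustration of the same pattern:

import re

pattern = r"%1(?![0-9])"  # '%1' not followed by another digit

assert re.search(pattern, "### User:\n%1\n### Response:\n") is not None  # would raise ValueError in generate()
assert re.search(pattern, "### User:\n{0}\n### Response:\n") is None     # accepted: '{0}' is the placeholder
assert re.search(pattern, "a %10 discount") is None                      # '%10' is left alone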
@@ -327,16 +333,29 @@ class GPT4All:
         if self._is_chat_session_activated:
             # check if there is only one message, i.e. system prompt:
-            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1
+            reset = len(self.current_chat_session) == 1
+            generate_kwargs["reset_context"] = reset
             self.current_chat_session.append({"role": "user", "content": prompt})
-            prompt = self._format_chat_prompt_template(
-                messages=self.current_chat_session[-1:],
-                default_prompt_header=self.current_chat_session[0]["content"]
-                if generate_kwargs["reset_context"]
-                else "",
-            )
+
+            if self._format_chat_prompt_template.__func__ is GPT4All._format_chat_prompt_template:
+                if reset:
+                    # ingest system prompt
+                    self.model.prompt_model(self.current_chat_session[0]["content"], "%1",
+                                            n_batch=n_batch, n_predict=0, special=True)
+                prompt_template = self._current_prompt_template.format("%1")
+            else:
+                warnings.warn(
+                    "_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
+                    DeprecationWarning,
+                )
+                # special tokens won't be processed
+                prompt = self._format_chat_prompt_template(
+                    self.current_chat_session[-1:],
+                    self.current_chat_session[0]["content"] if reset else "",
+                )
+                prompt_template = "%1"
         else:
+            prompt_template = "%1"
             generate_kwargs["reset_context"] = True

         # Prepare the callback, process the model response
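Tracing the default branch above: on the first turn of a session, the system prompt is ingested by itself (template "%1", n_predict=0 so nothing is generated, special=True so its control tokens are processed), and each user turn is then sent with the session's template, whose '{0}' placeholder is rewritten to the backend's '%1'. A sketch of the intended usage; the system prompt and template strings are illustrative:

with model.chat_session(
    system_prompt="You are a helpful assistant.",
    prompt_template="### User:\n{0}\n### Response:\n",
):
    model.generate("Hello!")        # first call resets context and ingests the system prompt
    model.generate("One more...")   # later calls reuse the live context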
@@ -365,14 +384,16 @@ class GPT4All:
         # Send the request to the model
         if streaming:
             return self.model.prompt_model_streaming(
-                prompt=prompt,
-                callback=_callback_wrapper(callback, output_collector),
+                prompt,
+                prompt_template,
+                _callback_wrapper(callback, output_collector),
                 **generate_kwargs,
             )

         self.model.prompt_model(
-            prompt=prompt,
-            callback=_callback_wrapper(callback, output_collector),
+            prompt,
+            prompt_template,
+            _callback_wrapper(callback, output_collector),
             **generate_kwargs,
         )
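Both call sites now pass prompt, prompt_template, and the wrapped callback positionally. For reference, the streaming path still returns a token generator; a minimal usage sketch inside a chat session:

with model.chat_session():
    for token in model.generate("Tell me a short joke.", streaming=True):
        print(token, end="", flush=True)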
@@ -423,24 +444,6 @@ class GPT4All:
             Formatted prompt.
         """
-        if isinstance(default_prompt_header, bool):
-            import warnings
-
-            warnings.warn(
-                "Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
-                DeprecationWarning,
-            )
-            default_prompt_header = ""
-
-        if isinstance(default_prompt_footer, bool):
-            import warnings
-
-            warnings.warn(
-                "Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
-                DeprecationWarning,
-            )
-            default_prompt_footer = ""
-
         full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""

         for message in messages: