fix chat-style prompt templates (#1970)
Also use a new version of Mistral OpenOrca.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
@@ -89,10 +89,12 @@ RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
 llmodel.llmodel_prompt.argtypes = [
     ctypes.c_void_p,
     ctypes.c_char_p,
+    ctypes.c_char_p,
     PromptCallback,
     ResponseCallback,
     RecalculateCallback,
     ctypes.POINTER(LLModelPromptContext),
+    ctypes.c_bool,
 ]

 llmodel.llmodel_prompt.restype = None
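
For orientation, a minimal sketch of how a raw call lines up with the updated argtypes. The handle `model` and the pointer `context` are assumed to exist, and the template text is illustrative; real callers should go through LLModel.prompt_model below.

# Sketch only: argument order matching the argtypes above.
llmodel.llmodel_prompt(
    model,                                                  # ctypes.c_void_p handle (assumed loaded)
    ctypes.c_char_p(b"What is the capital of France?"),     # prompt
    ctypes.c_char_p(b"### Human:\n%1\n### Assistant:\n"),   # prompt_template (illustrative)
    PromptCallback(lambda *args: True),                     # keep processing prompt tokens
    ResponseCallback(lambda *args: True),                   # keep generating response tokens
    RecalculateCallback(lambda *args: True),                # allow context recalculation
    context,                                                # POINTER(LLModelPromptContext), assumed populated
    False,                                                  # special: treat input as plain text
)
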
@@ -290,6 +292,7 @@ class LLModel:
     def prompt_model(
         self,
         prompt: str,
+        prompt_template: str,
         callback: ResponseCallbackType,
         n_predict: int = 4096,
         top_k: int = 40,
@@ -300,6 +303,7 @@ class LLModel:
         repeat_last_n: int = 10,
         context_erase: float = 0.75,
         reset_context: bool = False,
+        special: bool = False,
     ):
         """
         Generate response from model from a prompt.
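
A hedged usage sketch of the widened signature. It assumes `model` is an already-loaded LLModel instance and that the response callback receives a token id plus decoded text and returns True to continue; the template string is illustrative.

def on_response(token_id, response):
    # Print decoded text as it streams; returning True keeps generation going.
    print(response, end="", flush=True)
    return True

model.prompt_model(
    "What is the capital of France?",      # prompt
    "### Human:\n%1\n### Assistant:\n",    # prompt_template; the backend fills in %1
    on_response,
    n_predict=128,
    special=False,                         # True only when feeding raw/system text
)
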
@@ -326,9 +330,6 @@ class LLModel:
             prompt,
         )

-        prompt_bytes = prompt.encode()
-        prompt_ptr = ctypes.c_char_p(prompt_bytes)
-
         self._set_context(
             n_predict=n_predict,
             top_k=top_k,
@@ -343,16 +344,18 @@ class LLModel:

         llmodel.llmodel_prompt(
             self.model,
-            prompt_ptr,
+            ctypes.c_char_p(prompt.encode()),
+            ctypes.c_char_p(prompt_template.encode()),
             PromptCallback(self._prompt_callback),
             ResponseCallback(self._callback_decoder(callback)),
             RecalculateCallback(self._recalculate_callback),
             self.context,
+            special,
         )


     def prompt_model_streaming(
-        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
+        self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
         output_queue: Queue[str | Sentinel] = Queue()

@@ -369,15 +372,15 @@ class LLModel:

             return _generator_callback

-        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
-            self.prompt_model(prompt, callback, **kwargs)
+        def run_llmodel_prompt(prompt: str, prompt_template: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, prompt_template, callback, **kwargs)
             output_queue.put(Sentinel.TERMINATING_SYMBOL)

         # Kick off llmodel_prompt in separate thread so we can return generator
         # immediately
         thread = threading.Thread(
             target=run_llmodel_prompt,
-            args=(prompt, _generator_callback_wrapper(callback)),
+            args=(prompt, prompt_template, _generator_callback_wrapper(callback)),
             kwargs=kwargs,
         )
         thread.start()
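
As a usage sketch of the pattern above (again assuming a loaded LLModel named `model`, with placeholder prompt and template text): consuming the returned iterator drains the queue that the worker thread fills.

for token in model.prompt_model_streaming(
    "What is the capital of France?",      # prompt
    "### Human:\n%1\n### Assistant:\n",    # prompt_template
    n_predict=128,                         # forwarded to prompt_model via **kwargs
):
    print(token, end="", flush=True)
print()
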
@@ -4,8 +4,10 @@ Python only API for running all GPT4All models.
 from __future__ import annotations

 import os
+import re
 import sys
 import time
+import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Union
@@ -314,6 +316,10 @@ class GPT4All:
             Either the entire completion or a generator that yields the completion token by token.
         """

+        if re.search(r"%1(?![0-9])", self._current_prompt_template):
+            raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
+                             "placeholder, please use '{0}' instead.")
+
         # Preparing the model request
         generate_kwargs: Dict[str, Any] = dict(
             temp=temp,
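
In other words, templates passed through the Python API use Python-style '{0}' substitution, while '%1' stays reserved for the lower-level backend templates. A small illustrative sketch (these template strings are examples, not ones shipped with any model):

good_template = "### Human:\n{0}\n### Assistant:\n"   # accepted: '{0}' marks the prompt slot
bad_template = "### Human:\n%1\n### Assistant:\n"     # rejected: literal '%1' raises ValueError
# The negative lookahead means sequences like '%10' are not flagged by the check.
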
@@ -327,16 +333,29 @@ class GPT4All:

         if self._is_chat_session_activated:
             # check if there is only one message, i.e. system prompt:
-            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1
+            reset = len(self.current_chat_session) == 1
+            generate_kwargs["reset_context"] = reset
             self.current_chat_session.append({"role": "user", "content": prompt})

-            prompt = self._format_chat_prompt_template(
-                messages=self.current_chat_session[-1:],
-                default_prompt_header=self.current_chat_session[0]["content"]
-                if generate_kwargs["reset_context"]
-                else "",
-            )
+            if self._format_chat_prompt_template.__func__ is GPT4All._format_chat_prompt_template:
+                if reset:
+                    # ingest system prompt
+                    self.model.prompt_model(self.current_chat_session[0]["content"], "%1",
+                                            n_batch=n_batch, n_predict=0, special=True)
+                prompt_template = self._current_prompt_template.format("%1")
+            else:
+                warnings.warn(
+                    "_format_chat_prompt_template is deprecated. Please use a chat session with a prompt template.",
+                    DeprecationWarning,
+                )
+                # special tokens won't be processed
+                prompt = self._format_chat_prompt_template(
+                    self.current_chat_session[-1:],
+                    self.current_chat_session[0]["content"] if reset else "",
+                )
+                prompt_template = "%1"
         else:
+            prompt_template = "%1"
             generate_kwargs["reset_context"] = True

         # Prepare the callback, process the model response
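
A hedged end-to-end sketch of the supported path, a chat session carrying its own prompt template. The model filename and template text below are placeholders, and the keyword names assume the existing chat_session() signature.

from gpt4all import GPT4All

model = GPT4All("mistral-7b-openorca.gguf2.Q4_0.gguf")        # placeholder model file
with model.chat_session(
    system_prompt="You are a helpful assistant.",
    prompt_template="### Human:\n{0}\n### Assistant:\n",       # '{0}', not '%1'
):
    print(model.generate("What is the capital of France?", max_tokens=64))
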
@@ -365,14 +384,16 @@ class GPT4All:
         # Send the request to the model
         if streaming:
             return self.model.prompt_model_streaming(
-                prompt=prompt,
-                callback=_callback_wrapper(callback, output_collector),
+                prompt,
+                prompt_template,
+                _callback_wrapper(callback, output_collector),
                 **generate_kwargs,
             )

         self.model.prompt_model(
-            prompt=prompt,
-            callback=_callback_wrapper(callback, output_collector),
+            prompt,
+            prompt_template,
+            _callback_wrapper(callback, output_collector),
             **generate_kwargs,
         )

@@ -423,24 +444,6 @@ class GPT4All:
             Formatted prompt.
         """

-        if isinstance(default_prompt_header, bool):
-            import warnings
-
-            warnings.warn(
-                "Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
-                DeprecationWarning,
-            )
-            default_prompt_header = ""
-
-        if isinstance(default_prompt_footer, bool):
-            import warnings
-
-            warnings.warn(
-                "Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
-                DeprecationWarning,
-            )
-            default_prompt_footer = ""
-
         full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""

         for message in messages: