python: various fixes for GPT4All and Embed4All (#2130)

Key changes:
* honor empty system prompt argument
* current_chat_session is now read-only and defaults to None
* deprecate fallback prompt template for unknown models
* fix mistakes from #2086

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-03-15 11:49:58 -04:00 committed by GitHub
parent 53f109f519
commit 255568fb9a
7 changed files with 132 additions and 148 deletions
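
A minimal usage sketch of the behavior described in the key changes above (the model filename is taken from the models list in this diff; any chat model works, and output text will vary by model):

from gpt4all import GPT4All

model = GPT4All("mistral-7b-openorca.gguf2.Q4_0.gguf")

# current_chat_session is now a read-only property that is None outside of a chat session.
assert model.current_chat_session is None

# An explicit empty system prompt is now honored instead of being replaced by the model's default.
with model.chat_session(system_prompt="", prompt_template="### Human:\n{0}\n\n### Assistant:\n"):
    reply = model.generate("Name a color.", max_tokens=16)
    print(reply)
    print(model.current_chat_session)  # system, user, and assistant messages

# Leaving the session resets the history to None.
assert model.current_chat_session is None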


@@ -10,6 +10,7 @@
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
@@ -345,7 +346,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
if (m_supportsEmbedding)
if (isEmbedding)
d_ptr->ctx_params.embeddings = true;
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
@@ -612,22 +613,22 @@ struct EmbModelGroup {
std::vector<const char *> names;
};
static const EmbModelSpec NOPREFIX_SPEC {nullptr, nullptr};
static const EmbModelSpec NOPREFIX_SPEC {"", ""};
static const EmbModelSpec NOMIC_SPEC {"search_document", "search_query", {"clustering", "classification"}};
static const EmbModelSpec E5_SPEC {"passage", "query"};
static const EmbModelSpec NOMIC_1_5_SPEC {
"search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]"
"search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]",
};
static const EmbModelSpec LLM_EMBEDDER_SPEC {
"Represent this document for retrieval",
"Represent this query for retrieving relevant documents",
};
static const EmbModelSpec BGE_SPEC {
nullptr, "Represent this sentence for searching relevant passages",
"", "Represent this sentence for searching relevant passages",
};
static const EmbModelSpec E5_MISTRAL_SPEC {
nullptr, "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
"", "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
};
static const EmbModelGroup EMBEDDING_MODEL_SPECS[] {
@@ -738,18 +739,20 @@ void LLamaModel::embedInternal(
const llama_token bos_token = llama_token_bos(d_ptr->model);
const llama_token eos_token = llama_token_eos(d_ptr->model);
assert(shouldAddBOS());
bool addEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
bool useBOS = shouldAddBOS();
bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
// no EOS, optional BOS
auto tokenize = [this, addEOS](std::string text, TokenString &tokens, bool addBOS) {
if (!text.empty() && text[0] != ' ')
auto tokenize = [this, useBOS, useEOS, eos_token](std::string text, TokenString &tokens, bool wantBOS) {
if (!text.empty() && text[0] != ' ') {
text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
}
wantBOS &= useBOS;
tokens.resize(text.length()+4);
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), addBOS, false);
assert(addEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
tokens.resize(n_tokens - addEOS); // erase EOS/SEP
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
assert(useEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
tokens.resize(n_tokens - useEOS); // erase EOS/SEP
};
// tokenize the texts
@@ -784,7 +787,7 @@ void LLamaModel::embedInternal(
}
const uint32_t n_batch = llama_n_batch(d_ptr->ctx);
const uint32_t max_len = n_batch - (prefixTokens.size() + addEOS); // minus BOS/CLS and EOS/SEP
const uint32_t max_len = n_batch - (prefixTokens.size() + useEOS); // minus BOS/CLS and EOS/SEP
if (chunkOverlap >= max_len) {
throw std::logic_error("max chunk length of " + std::to_string(max_len) + " is smaller than overlap of " +
std::to_string(chunkOverlap) + " tokens");


@@ -1,7 +1,6 @@
from __future__ import annotations
import ctypes
import logging
import os
import platform
import re
@@ -17,8 +16,6 @@ if sys.version_info >= (3, 9):
else:
import importlib_resources
logger: logging.Logger = logging.getLogger(__name__)
# TODO: provide a config file to make this more robust
MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
@@ -130,7 +127,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).encode())
llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
@@ -323,7 +320,7 @@ class LLModel:
ctypes.byref(error),
)
if embedding_ptr.value is None:
if not embedding_ptr:
msg = "(unknown error)" if error.value is None else error.value.decode()
raise RuntimeError(f'Failed to generate embeddings: {msg}')
@@ -372,13 +369,6 @@ class LLModel:
self.buffer.clear()
self.buff_expecting_cont_bytes = 0
logger.info(
"LLModel.prompt_model -- prompt:\n"
+ "%s\n"
+ "===/LLModel.prompt_model -- prompt/===",
prompt,
)
self._set_context(
n_predict=n_predict,
top_k=top_k,


@@ -20,12 +20,9 @@ from urllib3.exceptions import IncompleteRead, ProtocolError
from . import _pyllmodel
# TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
DEFAULT_MODEL_CONFIG = {
"systemPrompt": "",
"promptTemplate": "### Human: \n{0}\n\n### Assistant:\n",
}
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
ConfigType = Dict[str, str]
MessageType = Dict[str, str]
@@ -34,18 +31,19 @@ MessageType = Dict[str, str]
class Embed4All:
"""
Python class that handles embeddings for GPT4All.
Args:
model_name: The name of the embedding model to use. Defaults to `all-MiniLM-L6-v2.gguf2.f16.gguf`.
All other arguments are passed to the GPT4All constructor. See its documentation for more info.
"""
MIN_DIMENSIONALITY = 64
def __init__(self, model_name: Optional[str] = None, n_threads: Optional[int] = None, **kwargs):
"""
Constructor
Args:
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
"""
self.gpt4all = GPT4All(model_name or 'all-MiniLM-L6-v2-f16.gguf', n_threads=n_threads, **kwargs)
def __init__(self, model_name: Optional[str] = None, **kwargs):
if model_name is None:
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
self.gpt4all = GPT4All(model_name, **kwargs)
@overload
def embed(
@@ -58,7 +56,7 @@ class Embed4All:
atlas: bool = ...,
) -> list[list[float]]: ...
def embed(self, text, prefix=None, dimensionality=None, long_text_mode="truncate", atlas=False):
def embed(self, text, prefix=None, dimensionality=None, long_text_mode="mean", atlas=False):
"""
Generate one or more embeddings.
@@ -94,22 +92,6 @@ class Embed4All:
class GPT4All:
"""
Python class that handles instantiation, downloading, generation and chat with GPT4All models.
"""
def __init__(
self,
model_name: str,
model_path: Optional[Union[str, os.PathLike[str]]] = None,
model_type: Optional[str] = None,
allow_download: bool = True,
n_threads: Optional[int] = None,
device: Optional[str] = "cpu",
n_ctx: int = 2048,
ngl: int = 100,
verbose: bool = False,
):
"""
Constructor
Args:
model_name: Name of GPT4All or custom model. Including ".gguf" file extension is optional but encouraged.
@@ -131,6 +113,19 @@ class GPT4All:
ngl: Number of GPU layers to use (Vulkan)
verbose: If True, print debug messages.
"""
def __init__(
self,
model_name: str,
model_path: Optional[Union[str, os.PathLike[str]]] = None,
model_type: Optional[str] = None,
allow_download: bool = True,
n_threads: Optional[int] = None,
device: Optional[str] = "cpu",
n_ctx: int = 2048,
ngl: int = 100,
verbose: bool = False,
):
self.model_type = model_type
# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
@@ -142,10 +137,13 @@ class GPT4All:
if n_threads is not None:
self.model.set_thread_count(n_threads)
self._is_chat_session_activated: bool = False
self.current_chat_session: List[MessageType] = empty_chat_session()
self._history: list[MessageType] | None = None
self._current_prompt_template: str = "{0}"
@property
def current_chat_session(self) -> list[MessageType] | None:
return self._history
@staticmethod
def list_models() -> List[ConfigType]:
"""
@@ -159,8 +157,9 @@ class GPT4All:
raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
return resp.json()
@staticmethod
@classmethod
def retrieve_model(
cls,
model_name: str,
model_path: Optional[Union[str, os.PathLike[str]]] = None,
allow_download: bool = True,
@@ -183,58 +182,51 @@ class GPT4All:
model_filename = append_extension_if_missing(model_name)
# get the config for the model
config: ConfigType = DEFAULT_MODEL_CONFIG
config: ConfigType = {}
if allow_download:
available_models = GPT4All.list_models()
available_models = cls.list_models()
for m in available_models:
if model_filename == m["filename"]:
config.update(m)
config["systemPrompt"] = config["systemPrompt"].strip()
tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
# change to Python-style formatting
config["promptTemplate"] = config["promptTemplate"].replace("%1", "{0}", 1).replace("%2", "{1}", 1)
m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
config.update(m)
break
# Validate download directory
if model_path is None:
try:
os.makedirs(DEFAULT_MODEL_DIRECTORY, exist_ok=True)
except OSError as exc:
raise ValueError(
f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. "
"Please specify model_path."
)
except OSError as e:
raise RuntimeError("Failed to create model download directory") from e
model_path = DEFAULT_MODEL_DIRECTORY
else:
model_path = str(model_path).replace("\\", "\\\\")
model_path = Path(model_path)
if not os.path.exists(model_path):
raise ValueError(f"Invalid model directory: {model_path}")
if not model_path.exists():
raise FileNotFoundError(f"Model directory does not exist: {model_path!r}")
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
if os.path.exists(model_dest):
config.pop("url", None)
config["path"] = model_dest
model_dest = model_path / model_filename
if model_dest.exists():
config["path"] = str(model_dest)
if verbose:
print("Found model file at", model_dest, file=sys.stderr)
# If model file does not exist, download
print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
elif allow_download:
url = config.pop("url", None)
config["path"] = GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
# If model file does not exist, download
config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
else:
raise ValueError("Failed to retrieve model")
raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
return config
@staticmethod
def download_model(
model_filename: str,
model_path: Union[str, os.PathLike[str]],
model_path: str | os.PathLike[str],
verbose: bool = True,
url: Optional[str] = None,
) -> str:
) -> str | os.PathLike[str]:
"""
Download model from https://gpt4all.io.
@@ -248,21 +240,17 @@ class GPT4All:
Model file destination.
"""
def get_download_url(model_filename):
if url:
return url
return f"https://gpt4all.io/models/gguf/{model_filename}"
# Download model
download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
download_url = get_download_url(model_filename)
download_path = Path(model_path) / model_filename
if url is None:
url = f"https://gpt4all.io/models/gguf/{model_filename}"
def make_request(offset=None):
headers = {}
if offset:
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
headers['Range'] = f'bytes={offset}-' # resume incomplete response
response = requests.get(download_url, stream=True, headers=headers)
response = requests.get(url, stream=True, headers=headers)
if response.status_code not in (200, 206):
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
@@ -311,7 +299,7 @@ class GPT4All:
time.sleep(2) # Sleep for a little bit so Windows can remove file lock
if verbose:
print("Model downloaded at:", download_path, file=sys.stderr)
print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
return download_path
def generate(
@@ -350,10 +338,6 @@ class GPT4All:
Either the entire completion or a generator that yields the completion token by token.
"""
if re.search(r"%1(?![0-9])", self._current_prompt_template):
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
"placeholder, please use '{0}' instead.")
# Preparing the model request
generate_kwargs: Dict[str, Any] = dict(
temp=temp,
@@ -366,17 +350,17 @@ class GPT4All:
n_predict=n_predict if n_predict is not None else max_tokens,
)
if self._is_chat_session_activated:
if self._history is not None:
# check if there is only one message, i.e. system prompt:
reset = len(self.current_chat_session) == 1
reset = len(self._history) == 1
generate_kwargs["reset_context"] = reset
self.current_chat_session.append({"role": "user", "content": prompt})
self._history.append({"role": "user", "content": prompt})
fct_func = self._format_chat_prompt_template.__func__ # type: ignore[attr-defined]
if fct_func is GPT4All._format_chat_prompt_template:
if reset:
# ingest system prompt
self.model.prompt_model(self.current_chat_session[0]["content"], "%1",
self.model.prompt_model(self._history[0]["content"], "%1",
_pyllmodel.empty_response_callback,
n_batch=n_batch, n_predict=0, special=True)
prompt_template = self._current_prompt_template.format("%1", "%2")
@@ -387,8 +371,8 @@ class GPT4All:
)
# special tokens won't be processed
prompt = self._format_chat_prompt_template(
self.current_chat_session[-1:],
self.current_chat_session[0]["content"] if reset else "",
self._history[-1:],
self._history[0]["content"] if reset else "",
)
prompt_template = "%1"
else:
@@ -399,11 +383,11 @@ class GPT4All:
output_collector: List[MessageType]
output_collector = [
{"content": ""}
] # placeholder for the self.current_chat_session if chat session is not activated
] # placeholder for the self._history if chat session is not activated
if self._is_chat_session_activated:
self.current_chat_session.append({"role": "assistant", "content": ""})
output_collector = self.current_chat_session
if self._history is not None:
self._history.append({"role": "assistant", "content": ""})
output_collector = self._history
def _callback_wrapper(
callback: _pyllmodel.ResponseCallbackType,
@@ -439,8 +423,8 @@ class GPT4All:
@contextmanager
def chat_session(
self,
system_prompt: str = "",
prompt_template: str = "",
system_prompt: str | None = None,
prompt_template: str | None = None,
):
"""
Context manager to hold an inference optimized chat session with a GPT4All model.
@@ -449,16 +433,27 @@ class GPT4All:
system_prompt: An initial instruction for the model.
prompt_template: Template for the prompts with {0} being replaced by the user message.
"""
# Code to acquire resource, e.g.:
self._is_chat_session_activated = True
self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
self._current_prompt_template = prompt_template or self.config["promptTemplate"]
if system_prompt is None:
system_prompt = self.config.get("systemPrompt", "")
if prompt_template is None:
if (tmpl := self.config.get("promptTemplate")) is None:
warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template "
"is deprecated. Defaulting to Alpaca.", DeprecationWarning)
tmpl = DEFAULT_PROMPT_TEMPLATE
prompt_template = tmpl
if re.search(r"%1(?![0-9])", prompt_template):
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
"placeholder, please use '{0}' instead.")
self._history = [{"role": "system", "content": system_prompt}]
self._current_prompt_template = prompt_template
try:
yield self
finally:
# Code to release resource, e.g.:
self._is_chat_session_activated = False
self.current_chat_session = empty_chat_session()
self._history = None
self._current_prompt_template = "{0}"
def _format_chat_prompt_template(
@@ -496,10 +491,6 @@ class GPT4All:
return full_prompt
def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
return [{"role": "system", "content": system_prompt}]
def append_extension_if_missing(model_name):
if not model_name.endswith((".bin", ".gguf")):
model_name += ".gguf"
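
For the deprecation noted in the key changes, a sketch of how chat_session() now behaves with a sideloaded model (the filename here is hypothetical, and the file must already exist in the model directory since allow_download=False skips both the models.json lookup and the download):

import warnings
from gpt4all import GPT4All

model = GPT4All("my-sideloaded-model.gguf", allow_download=False)

# With no prompt template in the config and none passed in, chat_session() warns
# and falls back to the Alpaca-style DEFAULT_PROMPT_TEMPLATE.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with model.chat_session():
        pass
print([str(w.message) for w in caught])

# Legacy '%1' placeholders are now rejected when the session is opened:
try:
    with model.chat_session(prompt_template="### Human:\n%1\n\n### Assistant:\n"):
        pass
except ValueError as e:
    print(e)  # advises using '{0}' instead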


@@ -115,13 +115,13 @@ def test_empty_embedding():
output = embedder.embed(text)
def test_download_model(tmp_path: Path):
import gpt4all.gpt4all
old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path) # temporary pytest directory to ensure a download happens
from gpt4all import gpt4all
old_default_dir = gpt4all.DEFAULT_MODEL_DIRECTORY
gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens
try:
model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
model_path = tmp_path / model.config['filename']
assert model_path.absolute() == Path(model.config['path']).absolute()
assert model_path.stat().st_size == int(model.config['filesize'])
finally:
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = old_default_dir
gpt4all.DEFAULT_MODEL_DIRECTORY = old_default_dir


@@ -24,7 +24,7 @@ const DEFAULT_LIBRARIES_DIRECTORY = librarySearchPaths.join(";");
const DEFAULT_MODEL_CONFIG = {
systemPrompt: "",
promptTemplate: "### Human: \n%1\n### Assistant:\n",
promptTemplate: "### Human:\n%1\n\n### Assistant:\n",
}
const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models2.json";


@@ -29,7 +29,7 @@
"description": "<strong>Strong overall fast chat model</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by Mistral AI<li>Finetuned on OpenOrca dataset curated via <a href=\"https://atlas.nomic.ai/\">Nomic Atlas</a><li>Licensed for commercial use</ul>",
"url": "https://gpt4all.io/models/gguf/mistral-7b-openorca.gguf2.Q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>"
"systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>\n"
},
{
"order": "c",
@@ -154,7 +154,7 @@
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
},
{
"order": "j",
@@ -170,7 +170,7 @@
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf",
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
},
{
"order": "k",