mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-31 23:27:17 +00:00
python: various fixes for GPT4All and Embed4All (#2130)
Key changes: * honor empty system prompt argument * current_chat_session is now read-only and defaults to None * deprecate fallback prompt template for unknown models * fix mistakes from #2086 Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
53f109f519
commit
255568fb9a
@ -10,6 +10,7 @@
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
@ -345,7 +346,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
|
||||
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
|
||||
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
|
||||
|
||||
if (m_supportsEmbedding)
|
||||
if (isEmbedding)
|
||||
d_ptr->ctx_params.embeddings = true;
|
||||
|
||||
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
|
||||
@ -612,22 +613,22 @@ struct EmbModelGroup {
|
||||
std::vector<const char *> names;
|
||||
};
|
||||
|
||||
static const EmbModelSpec NOPREFIX_SPEC {nullptr, nullptr};
|
||||
static const EmbModelSpec NOPREFIX_SPEC {"", ""};
|
||||
static const EmbModelSpec NOMIC_SPEC {"search_document", "search_query", {"clustering", "classification"}};
|
||||
static const EmbModelSpec E5_SPEC {"passage", "query"};
|
||||
|
||||
static const EmbModelSpec NOMIC_1_5_SPEC {
|
||||
"search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]"
|
||||
"search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]",
|
||||
};
|
||||
static const EmbModelSpec LLM_EMBEDDER_SPEC {
|
||||
"Represent this document for retrieval",
|
||||
"Represent this query for retrieving relevant documents",
|
||||
};
|
||||
static const EmbModelSpec BGE_SPEC {
|
||||
nullptr, "Represent this sentence for searching relevant passages",
|
||||
"", "Represent this sentence for searching relevant passages",
|
||||
};
|
||||
static const EmbModelSpec E5_MISTRAL_SPEC {
|
||||
nullptr, "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
|
||||
"", "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
|
||||
};
|
||||
|
||||
static const EmbModelGroup EMBEDDING_MODEL_SPECS[] {
|
||||
@ -738,18 +739,20 @@ void LLamaModel::embedInternal(
|
||||
const llama_token bos_token = llama_token_bos(d_ptr->model);
|
||||
const llama_token eos_token = llama_token_eos(d_ptr->model);
|
||||
|
||||
assert(shouldAddBOS());
|
||||
bool addEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
|
||||
bool useBOS = shouldAddBOS();
|
||||
bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
|
||||
|
||||
// no EOS, optional BOS
|
||||
auto tokenize = [this, addEOS](std::string text, TokenString &tokens, bool addBOS) {
|
||||
if (!text.empty() && text[0] != ' ')
|
||||
auto tokenize = [this, useBOS, useEOS, eos_token](std::string text, TokenString &tokens, bool wantBOS) {
|
||||
if (!text.empty() && text[0] != ' ') {
|
||||
text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
|
||||
}
|
||||
wantBOS &= useBOS;
|
||||
|
||||
tokens.resize(text.length()+4);
|
||||
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), addBOS, false);
|
||||
assert(addEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
|
||||
tokens.resize(n_tokens - addEOS); // erase EOS/SEP
|
||||
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
|
||||
assert(useEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
|
||||
tokens.resize(n_tokens - useEOS); // erase EOS/SEP
|
||||
};
|
||||
|
||||
// tokenize the texts
|
||||
@ -784,7 +787,7 @@ void LLamaModel::embedInternal(
|
||||
}
|
||||
|
||||
const uint32_t n_batch = llama_n_batch(d_ptr->ctx);
|
||||
const uint32_t max_len = n_batch - (prefixTokens.size() + addEOS); // minus BOS/CLS and EOS/SEP
|
||||
const uint32_t max_len = n_batch - (prefixTokens.size() + useEOS); // minus BOS/CLS and EOS/SEP
|
||||
if (chunkOverlap >= max_len) {
|
||||
throw std::logic_error("max chunk length of " + std::to_string(max_len) + " is smaller than overlap of " +
|
||||
std::to_string(chunkOverlap) + " tokens");
|
||||
|
@ -317,10 +317,10 @@ are used instead of model-specific system and prompt templates:
|
||||
=== "Output"
|
||||
```
|
||||
default system template: ''
|
||||
default prompt template: '### Human: \n{0}\n\n### Assistant:\n'
|
||||
default prompt template: '### Human:\n{0}\n\n### Assistant:\n'
|
||||
|
||||
session system template: ''
|
||||
session prompt template: '### Human: \n{0}\n\n### Assistant:\n'
|
||||
session prompt template: '### Human:\n{0}\n\n### Assistant:\n'
|
||||
```
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
@ -17,8 +16,6 @@ if sys.version_info >= (3, 9):
|
||||
else:
|
||||
import importlib_resources
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: provide a config file to make this more robust
|
||||
MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
|
||||
@ -130,7 +127,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
|
||||
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_threadCount.restype = ctypes.c_int32
|
||||
|
||||
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())
|
||||
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).encode())
|
||||
|
||||
llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
|
||||
llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
|
||||
@ -323,7 +320,7 @@ class LLModel:
|
||||
ctypes.byref(error),
|
||||
)
|
||||
|
||||
if embedding_ptr.value is None:
|
||||
if not embedding_ptr:
|
||||
msg = "(unknown error)" if error.value is None else error.value.decode()
|
||||
raise RuntimeError(f'Failed to generate embeddings: {msg}')
|
||||
|
||||
@ -372,13 +369,6 @@ class LLModel:
|
||||
self.buffer.clear()
|
||||
self.buff_expecting_cont_bytes = 0
|
||||
|
||||
logger.info(
|
||||
"LLModel.prompt_model -- prompt:\n"
|
||||
+ "%s\n"
|
||||
+ "===/LLModel.prompt_model -- prompt/===",
|
||||
prompt,
|
||||
)
|
||||
|
||||
self._set_context(
|
||||
n_predict=n_predict,
|
||||
top_k=top_k,
|
||||
|
@ -20,12 +20,9 @@ from urllib3.exceptions import IncompleteRead, ProtocolError
|
||||
from . import _pyllmodel
|
||||
|
||||
# TODO: move to config
|
||||
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
|
||||
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
|
||||
|
||||
DEFAULT_MODEL_CONFIG = {
|
||||
"systemPrompt": "",
|
||||
"promptTemplate": "### Human: \n{0}\n\n### Assistant:\n",
|
||||
}
|
||||
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
|
||||
|
||||
ConfigType = Dict[str, str]
|
||||
MessageType = Dict[str, str]
|
||||
@ -34,18 +31,19 @@ MessageType = Dict[str, str]
|
||||
class Embed4All:
|
||||
"""
|
||||
Python class that handles embeddings for GPT4All.
|
||||
|
||||
Args:
|
||||
model_name: The name of the embedding model to use. Defaults to `all-MiniLM-L6-v2.gguf2.f16.gguf`.
|
||||
|
||||
All other arguments are passed to the GPT4All constructor. See its documentation for more info.
|
||||
"""
|
||||
|
||||
MIN_DIMENSIONALITY = 64
|
||||
|
||||
def __init__(self, model_name: Optional[str] = None, n_threads: Optional[int] = None, **kwargs):
|
||||
"""
|
||||
Constructor
|
||||
|
||||
Args:
|
||||
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
|
||||
"""
|
||||
self.gpt4all = GPT4All(model_name or 'all-MiniLM-L6-v2-f16.gguf', n_threads=n_threads, **kwargs)
|
||||
def __init__(self, model_name: Optional[str] = None, **kwargs):
|
||||
if model_name is None:
|
||||
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
|
||||
self.gpt4all = GPT4All(model_name, **kwargs)
|
||||
|
||||
@overload
|
||||
def embed(
|
||||
@ -58,7 +56,7 @@ class Embed4All:
|
||||
atlas: bool = ...,
|
||||
) -> list[list[float]]: ...
|
||||
|
||||
def embed(self, text, prefix=None, dimensionality=None, long_text_mode="truncate", atlas=False):
|
||||
def embed(self, text, prefix=None, dimensionality=None, long_text_mode="mean", atlas=False):
|
||||
"""
|
||||
Generate one or more embeddings.
|
||||
|
||||
@ -94,6 +92,26 @@ class Embed4All:
|
||||
class GPT4All:
|
||||
"""
|
||||
Python class that handles instantiation, downloading, generation and chat with GPT4All models.
|
||||
|
||||
Args:
|
||||
model_name: Name of GPT4All or custom model. Including ".gguf" file extension is optional but encouraged.
|
||||
model_path: Path to directory containing model file or, if file does not exist, where to download model.
|
||||
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
|
||||
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
||||
descriptive identifier for user. Default is None.
|
||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
|
||||
device: The processing unit on which the GPT4All model will run. It can be set to:
|
||||
- "cpu": Model will run on the central processing unit.
|
||||
- "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
|
||||
- "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
|
||||
Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name if it's available.
|
||||
Default is "cpu".
|
||||
|
||||
Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
|
||||
n_ctx: Maximum size of context window
|
||||
ngl: Number of GPU layers to use (Vulkan)
|
||||
verbose: If True, print debug messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -108,29 +126,6 @@ class GPT4All:
|
||||
ngl: int = 100,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Constructor
|
||||
|
||||
Args:
|
||||
model_name: Name of GPT4All or custom model. Including ".gguf" file extension is optional but encouraged.
|
||||
model_path: Path to directory containing model file or, if file does not exist, where to download model.
|
||||
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
|
||||
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
||||
descriptive identifier for user. Default is None.
|
||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
|
||||
device: The processing unit on which the GPT4All model will run. It can be set to:
|
||||
- "cpu": Model will run on the central processing unit.
|
||||
- "gpu": Model will run on the best available graphics processing unit, irrespective of its vendor.
|
||||
- "amd", "nvidia", "intel": Model will run on the best available GPU from the specified vendor.
|
||||
Alternatively, a specific GPU name can also be provided, and the model will run on the GPU that matches the name if it's available.
|
||||
Default is "cpu".
|
||||
|
||||
Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
|
||||
n_ctx: Maximum size of context window
|
||||
ngl: Number of GPU layers to use (Vulkan)
|
||||
verbose: If True, print debug messages.
|
||||
"""
|
||||
self.model_type = model_type
|
||||
# Retrieve model and download if allowed
|
||||
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
|
||||
@ -142,10 +137,13 @@ class GPT4All:
|
||||
if n_threads is not None:
|
||||
self.model.set_thread_count(n_threads)
|
||||
|
||||
self._is_chat_session_activated: bool = False
|
||||
self.current_chat_session: List[MessageType] = empty_chat_session()
|
||||
self._history: list[MessageType] | None = None
|
||||
self._current_prompt_template: str = "{0}"
|
||||
|
||||
@property
|
||||
def current_chat_session(self) -> list[MessageType] | None:
|
||||
return self._history
|
||||
|
||||
@staticmethod
|
||||
def list_models() -> List[ConfigType]:
|
||||
"""
|
||||
@ -159,8 +157,9 @@ class GPT4All:
|
||||
raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
|
||||
return resp.json()
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def retrieve_model(
|
||||
cls,
|
||||
model_name: str,
|
||||
model_path: Optional[Union[str, os.PathLike[str]]] = None,
|
||||
allow_download: bool = True,
|
||||
@ -183,58 +182,51 @@ class GPT4All:
|
||||
model_filename = append_extension_if_missing(model_name)
|
||||
|
||||
# get the config for the model
|
||||
config: ConfigType = DEFAULT_MODEL_CONFIG
|
||||
config: ConfigType = {}
|
||||
if allow_download:
|
||||
available_models = GPT4All.list_models()
|
||||
available_models = cls.list_models()
|
||||
|
||||
for m in available_models:
|
||||
if model_filename == m["filename"]:
|
||||
config.update(m)
|
||||
config["systemPrompt"] = config["systemPrompt"].strip()
|
||||
tmpl = m.get("promptTemplate", DEFAULT_PROMPT_TEMPLATE)
|
||||
# change to Python-style formatting
|
||||
config["promptTemplate"] = config["promptTemplate"].replace("%1", "{0}", 1).replace("%2", "{1}", 1)
|
||||
m["promptTemplate"] = tmpl.replace("%1", "{0}", 1).replace("%2", "{1}", 1)
|
||||
config.update(m)
|
||||
break
|
||||
|
||||
# Validate download directory
|
||||
if model_path is None:
|
||||
try:
|
||||
os.makedirs(DEFAULT_MODEL_DIRECTORY, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ValueError(
|
||||
f"Failed to create model download directory at {DEFAULT_MODEL_DIRECTORY}: {exc}. "
|
||||
"Please specify model_path."
|
||||
)
|
||||
except OSError as e:
|
||||
raise RuntimeError("Failed to create model download directory") from e
|
||||
model_path = DEFAULT_MODEL_DIRECTORY
|
||||
else:
|
||||
model_path = str(model_path).replace("\\", "\\\\")
|
||||
model_path = Path(model_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise ValueError(f"Invalid model directory: {model_path}")
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model directory does not exist: {model_path!r}")
|
||||
|
||||
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
|
||||
if os.path.exists(model_dest):
|
||||
config.pop("url", None)
|
||||
config["path"] = model_dest
|
||||
model_dest = model_path / model_filename
|
||||
if model_dest.exists():
|
||||
config["path"] = str(model_dest)
|
||||
if verbose:
|
||||
print("Found model file at", model_dest, file=sys.stderr)
|
||||
|
||||
# If model file does not exist, download
|
||||
print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
|
||||
elif allow_download:
|
||||
url = config.pop("url", None)
|
||||
|
||||
config["path"] = GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
|
||||
# If model file does not exist, download
|
||||
config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
|
||||
else:
|
||||
raise ValueError("Failed to retrieve model")
|
||||
raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
|
||||
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def download_model(
|
||||
model_filename: str,
|
||||
model_path: Union[str, os.PathLike[str]],
|
||||
model_path: str | os.PathLike[str],
|
||||
verbose: bool = True,
|
||||
url: Optional[str] = None,
|
||||
) -> str:
|
||||
) -> str | os.PathLike[str]:
|
||||
"""
|
||||
Download model from https://gpt4all.io.
|
||||
|
||||
@ -248,21 +240,17 @@ class GPT4All:
|
||||
Model file destination.
|
||||
"""
|
||||
|
||||
def get_download_url(model_filename):
|
||||
if url:
|
||||
return url
|
||||
return f"https://gpt4all.io/models/gguf/{model_filename}"
|
||||
|
||||
# Download model
|
||||
download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
|
||||
download_url = get_download_url(model_filename)
|
||||
download_path = Path(model_path) / model_filename
|
||||
if url is None:
|
||||
url = f"https://gpt4all.io/models/gguf/{model_filename}"
|
||||
|
||||
def make_request(offset=None):
|
||||
headers = {}
|
||||
if offset:
|
||||
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
|
||||
headers['Range'] = f'bytes={offset}-' # resume incomplete response
|
||||
response = requests.get(download_url, stream=True, headers=headers)
|
||||
response = requests.get(url, stream=True, headers=headers)
|
||||
if response.status_code not in (200, 206):
|
||||
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
|
||||
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
|
||||
@ -311,7 +299,7 @@ class GPT4All:
|
||||
time.sleep(2) # Sleep for a little bit so Windows can remove file lock
|
||||
|
||||
if verbose:
|
||||
print("Model downloaded at:", download_path, file=sys.stderr)
|
||||
print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
|
||||
return download_path
|
||||
|
||||
def generate(
|
||||
@ -350,10 +338,6 @@ class GPT4All:
|
||||
Either the entire completion or a generator that yields the completion token by token.
|
||||
"""
|
||||
|
||||
if re.search(r"%1(?![0-9])", self._current_prompt_template):
|
||||
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
|
||||
"placeholder, please use '{0}' instead.")
|
||||
|
||||
# Preparing the model request
|
||||
generate_kwargs: Dict[str, Any] = dict(
|
||||
temp=temp,
|
||||
@ -366,17 +350,17 @@ class GPT4All:
|
||||
n_predict=n_predict if n_predict is not None else max_tokens,
|
||||
)
|
||||
|
||||
if self._is_chat_session_activated:
|
||||
if self._history is not None:
|
||||
# check if there is only one message, i.e. system prompt:
|
||||
reset = len(self.current_chat_session) == 1
|
||||
reset = len(self._history) == 1
|
||||
generate_kwargs["reset_context"] = reset
|
||||
self.current_chat_session.append({"role": "user", "content": prompt})
|
||||
self._history.append({"role": "user", "content": prompt})
|
||||
|
||||
fct_func = self._format_chat_prompt_template.__func__ # type: ignore[attr-defined]
|
||||
if fct_func is GPT4All._format_chat_prompt_template:
|
||||
if reset:
|
||||
# ingest system prompt
|
||||
self.model.prompt_model(self.current_chat_session[0]["content"], "%1",
|
||||
self.model.prompt_model(self._history[0]["content"], "%1",
|
||||
_pyllmodel.empty_response_callback,
|
||||
n_batch=n_batch, n_predict=0, special=True)
|
||||
prompt_template = self._current_prompt_template.format("%1", "%2")
|
||||
@ -387,8 +371,8 @@ class GPT4All:
|
||||
)
|
||||
# special tokens won't be processed
|
||||
prompt = self._format_chat_prompt_template(
|
||||
self.current_chat_session[-1:],
|
||||
self.current_chat_session[0]["content"] if reset else "",
|
||||
self._history[-1:],
|
||||
self._history[0]["content"] if reset else "",
|
||||
)
|
||||
prompt_template = "%1"
|
||||
else:
|
||||
@ -399,11 +383,11 @@ class GPT4All:
|
||||
output_collector: List[MessageType]
|
||||
output_collector = [
|
||||
{"content": ""}
|
||||
] # placeholder for the self.current_chat_session if chat session is not activated
|
||||
] # placeholder for the self._history if chat session is not activated
|
||||
|
||||
if self._is_chat_session_activated:
|
||||
self.current_chat_session.append({"role": "assistant", "content": ""})
|
||||
output_collector = self.current_chat_session
|
||||
if self._history is not None:
|
||||
self._history.append({"role": "assistant", "content": ""})
|
||||
output_collector = self._history
|
||||
|
||||
def _callback_wrapper(
|
||||
callback: _pyllmodel.ResponseCallbackType,
|
||||
@ -439,8 +423,8 @@ class GPT4All:
|
||||
@contextmanager
|
||||
def chat_session(
|
||||
self,
|
||||
system_prompt: str = "",
|
||||
prompt_template: str = "",
|
||||
system_prompt: str | None = None,
|
||||
prompt_template: str | None = None,
|
||||
):
|
||||
"""
|
||||
Context manager to hold an inference optimized chat session with a GPT4All model.
|
||||
@ -449,16 +433,27 @@ class GPT4All:
|
||||
system_prompt: An initial instruction for the model.
|
||||
prompt_template: Template for the prompts with {0} being replaced by the user message.
|
||||
"""
|
||||
# Code to acquire resource, e.g.:
|
||||
self._is_chat_session_activated = True
|
||||
self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
|
||||
self._current_prompt_template = prompt_template or self.config["promptTemplate"]
|
||||
|
||||
if system_prompt is None:
|
||||
system_prompt = self.config.get("systemPrompt", "")
|
||||
|
||||
if prompt_template is None:
|
||||
if (tmpl := self.config.get("promptTemplate")) is None:
|
||||
warnings.warn("Use of a sideloaded model or allow_download=False without specifying a prompt template "
|
||||
"is deprecated. Defaulting to Alpaca.", DeprecationWarning)
|
||||
tmpl = DEFAULT_PROMPT_TEMPLATE
|
||||
prompt_template = tmpl
|
||||
|
||||
if re.search(r"%1(?![0-9])", prompt_template):
|
||||
raise ValueError("Prompt template containing a literal '%1' is not supported. For a prompt "
|
||||
"placeholder, please use '{0}' instead.")
|
||||
|
||||
self._history = [{"role": "system", "content": system_prompt}]
|
||||
self._current_prompt_template = prompt_template
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
# Code to release resource, e.g.:
|
||||
self._is_chat_session_activated = False
|
||||
self.current_chat_session = empty_chat_session()
|
||||
self._history = None
|
||||
self._current_prompt_template = "{0}"
|
||||
|
||||
def _format_chat_prompt_template(
|
||||
@ -496,10 +491,6 @@ class GPT4All:
|
||||
return full_prompt
|
||||
|
||||
|
||||
def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
|
||||
return [{"role": "system", "content": system_prompt}]
|
||||
|
||||
|
||||
def append_extension_if_missing(model_name):
|
||||
if not model_name.endswith((".bin", ".gguf")):
|
||||
model_name += ".gguf"
|
||||
|
@ -115,13 +115,13 @@ def test_empty_embedding():
|
||||
output = embedder.embed(text)
|
||||
|
||||
def test_download_model(tmp_path: Path):
|
||||
import gpt4all.gpt4all
|
||||
old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
|
||||
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path) # temporary pytest directory to ensure a download happens
|
||||
from gpt4all import gpt4all
|
||||
old_default_dir = gpt4all.DEFAULT_MODEL_DIRECTORY
|
||||
gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens
|
||||
try:
|
||||
model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
|
||||
model_path = tmp_path / model.config['filename']
|
||||
assert model_path.absolute() == Path(model.config['path']).absolute()
|
||||
assert model_path.stat().st_size == int(model.config['filesize'])
|
||||
finally:
|
||||
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = old_default_dir
|
||||
gpt4all.DEFAULT_MODEL_DIRECTORY = old_default_dir
|
||||
|
@ -24,7 +24,7 @@ const DEFAULT_LIBRARIES_DIRECTORY = librarySearchPaths.join(";");
|
||||
|
||||
const DEFAULT_MODEL_CONFIG = {
|
||||
systemPrompt: "",
|
||||
promptTemplate: "### Human: \n%1\n### Assistant:\n",
|
||||
promptTemplate: "### Human:\n%1\n\n### Assistant:\n",
|
||||
}
|
||||
|
||||
const DEFAULT_MODEL_LIST_URL = "https://gpt4all.io/models/models2.json";
|
||||
|
@ -29,7 +29,7 @@
|
||||
"description": "<strong>Strong overall fast chat model</strong><br><ul><li>Fast responses</li><li>Chat based model</li><li>Trained by Mistral AI<li>Finetuned on OpenOrca dataset curated via <a href=\"https://atlas.nomic.ai/\">Nomic Atlas</a><li>Licensed for commercial use</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/mistral-7b-openorca.gguf2.Q4_0.gguf",
|
||||
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
|
||||
"systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>"
|
||||
"systemPrompt": "<|im_start|>system\nYou are MistralOrca, a large language model trained by Alignment Lab AI. For multi-step problems, write out your reasoning for each step.\n<|im_end|>\n"
|
||||
},
|
||||
{
|
||||
"order": "c",
|
||||
@ -42,7 +42,7 @@
|
||||
"parameters": "7 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "Mistral",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<strong>Strong overall fast instruction following model</strong><br><ul><li>Fast responses</li><li>Trained by Mistral AI<li>Uncensored</li><li>Licensed for commercial use</li></ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/mistral-7b-instruct-v0.1.Q4_0.gguf",
|
||||
"promptTemplate": "[INST] %1 [/INST]"
|
||||
@ -58,7 +58,7 @@
|
||||
"parameters": "7 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "Falcon",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<strong>Very fast model with good quality</strong><br><ul><li>Fastest responses</li><li>Instruction based</li><li>Trained by TII<li>Finetuned by Nomic AI<li>Licensed for commercial use</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/gpt4all-falcon-newbpe-q4_0.gguf",
|
||||
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n"
|
||||
@ -74,7 +74,7 @@
|
||||
"parameters": "7 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "LLaMA2",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/orca-2-7b.Q4_0.gguf"
|
||||
},
|
||||
@ -89,7 +89,7 @@
|
||||
"parameters": "13 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "LLaMA2",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<ul><li>Instruction based<li>Trained by Microsoft<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/orca-2-13b.Q4_0.gguf"
|
||||
},
|
||||
@ -104,7 +104,7 @@
|
||||
"parameters": "13 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "LLaMA2",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<strong>Strong overall larger model</strong><br><ul><li>Instruction based<li>Gives very long responses<li>Finetuned with only 1k of high-quality data<li>Trained by Microsoft and Peking University<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/wizardlm-13b-v1.2.Q4_0.gguf"
|
||||
},
|
||||
@ -119,7 +119,7 @@
|
||||
"parameters": "13 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "LLaMA2",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<strong>Extremely good model</strong><br><ul><li>Instruction based<li>Gives long responses<li>Curated with 300,000 uncensored instructions<li>Trained by Nous Research<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/nous-hermes-llama2-13b.Q4_0.gguf",
|
||||
"promptTemplate": "### Instruction:\n%1\n\n### Response:\n"
|
||||
@ -135,7 +135,7 @@
|
||||
"parameters": "13 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "LLaMA",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<strong>Very good overall model</strong><br><ul><li>Instruction based<li>Based on the same dataset as Groovy<li>Slower than Groovy, with higher quality responses<li>Trained by Nomic AI<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/gpt4all-13b-snoozy-q4_0.gguf"
|
||||
},
|
||||
@ -154,7 +154,7 @@
|
||||
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat-newbpe-q4_0.gguf",
|
||||
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
|
||||
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
|
||||
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
|
||||
},
|
||||
{
|
||||
"order": "j",
|
||||
@ -170,7 +170,7 @@
|
||||
"description": "<strong>Good model with novel architecture</strong><br><ul><li>Fast responses<li>Chat based<li>Trained by Mosaic ML<li>Cannot be used commercially</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/mpt-7b-chat.gguf4.Q4_0.gguf",
|
||||
"promptTemplate": "<|im_start|>user\n%1<|im_end|>\n<|im_start|>assistant\n%2<|im_end|>\n",
|
||||
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>"
|
||||
"systemPrompt": "<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.<|im_end|>\n"
|
||||
},
|
||||
{
|
||||
"order": "k",
|
||||
@ -200,7 +200,7 @@
|
||||
"parameters": "3 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "Replit",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"promptTemplate": "%1",
|
||||
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>Licensed for commercial use<li>WARNING: Not available for chat GUI</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/replit-code-v1_5-3b-newbpe-q4_0.gguf"
|
||||
@ -217,7 +217,7 @@
|
||||
"parameters": "7 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "Starcoder",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"promptTemplate": "%1",
|
||||
"description": "<strong>Trained on subset of the Stack</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</ul>",
|
||||
"url": "https://gpt4all.io/models/gguf/starcoder-newbpe-q4_0.gguf"
|
||||
@ -234,7 +234,7 @@
|
||||
"parameters": "7 billion",
|
||||
"quant": "q4_0",
|
||||
"type": "LLaMA",
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"promptTemplate": "%1",
|
||||
"description": "<strong>Trained on collection of Python and TypeScript</strong><br><ul><li>Code completion based<li>WARNING: Not available for chat GUI</li>",
|
||||
"url": "https://gpt4all.io/models/gguf/rift-coder-v0-7b-q4_0.gguf"
|
||||
@ -253,7 +253,7 @@
|
||||
"quant": "f16",
|
||||
"type": "Bert",
|
||||
"embeddingModel": true,
|
||||
"systemPrompt": " ",
|
||||
"systemPrompt": "",
|
||||
"description": "<strong>LocalDocs text embeddings model</strong><br><ul><li>For use with LocalDocs feature<li>Used for retrieval augmented generation (RAG)",
|
||||
"url": "https://gpt4all.io/models/gguf/all-MiniLM-L6-v2-f16.gguf"
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user