Implement configurable context length (#1749)

This commit is contained in:
Jared Van Bortel 2023-12-16 17:58:15 -05:00 committed by GitHub
parent 7aa0f779de
commit d1c56b8b28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 291 additions and 135 deletions

View File

@ -714,8 +714,9 @@ Bert::~Bert() {
bert_free(d_ptr->ctx); bert_free(d_ptr->ctx);
} }
bool Bert::loadModel(const std::string &modelPath) bool Bert::loadModel(const std::string &modelPath, int n_ctx)
{ {
(void)n_ctx;
d_ptr->ctx = bert_load_from_file(modelPath.c_str()); d_ptr->ctx = bert_load_from_file(modelPath.c_str());
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = d_ptr->ctx != nullptr; d_ptr->modelLoaded = d_ptr->ctx != nullptr;
@ -728,8 +729,10 @@ bool Bert::isModelLoaded() const
return d_ptr->modelLoaded; return d_ptr->modelLoaded;
} }
size_t Bert::requiredMem(const std::string &/*modelPath*/) size_t Bert::requiredMem(const std::string &modelPath, int n_ctx)
{ {
(void)modelPath;
(void)n_ctx;
return 0; return 0;
} }

View File

@ -18,9 +18,9 @@ public:
bool supportsEmbedding() const override { return true; } bool supportsEmbedding() const override { return true; }
bool supportsCompletion() const override { return true; } bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override; bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override; bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override; size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override; size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override; size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override; size_t restoreState(const uint8_t *src) override;

View File

@ -676,7 +676,8 @@ GPTJ::GPTJ()
d_ptr->modelLoaded = false; d_ptr->modelLoaded = false;
} }
size_t GPTJ::requiredMem(const std::string &modelPath) { size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
gptj_model dummy_model; gptj_model dummy_model;
gpt_vocab dummy_vocab; gpt_vocab dummy_vocab;
size_t mem_req; size_t mem_req;
@ -684,7 +685,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath) {
return mem_req; return mem_req;
} }
bool GPTJ::loadModel(const std::string &modelPath) { bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) {
(void)n_ctx;
std::mt19937 rng(time(NULL)); std::mt19937 rng(time(NULL));
d_ptr->rng = rng; d_ptr->rng = rng;

View File

@ -17,9 +17,9 @@ public:
bool supportsEmbedding() const override { return false; } bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; } bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override; bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override; bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override; size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override; size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override; size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override; size_t restoreState(const uint8_t *src) override;

View File

@ -120,7 +120,8 @@ struct llama_file_hparams {
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
}; };
size_t LLamaModel::requiredMem(const std::string &modelPath) { size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) {
// TODO(cebtenzzre): update to GGUF
auto fin = std::ifstream(modelPath, std::ios::binary); auto fin = std::ifstream(modelPath, std::ios::binary);
fin.seekg(0, std::ios_base::end); fin.seekg(0, std::ios_base::end);
size_t filesize = fin.tellg(); size_t filesize = fin.tellg();
@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer)); fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot)); fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype)); fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
const size_t n_ctx = 2048;
const size_t kvcache_element_size = 2; // fp16 const size_t kvcache_element_size = 2; // fp16
const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size; const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
return filesize + est_kvcache_size; return filesize + est_kvcache_size;
} }
bool LLamaModel::loadModel(const std::string &modelPath) bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx)
{ {
gpt_params params; gpt_params params;
// load the model if (n_ctx < 8) {
std::cerr << "warning: minimum context size is 8, using minimum size.\n";
n_ctx = 8;
}
// -- load the model --
d_ptr->model_params = llama_model_default_params(); d_ptr->model_params = llama_model_default_params();
d_ptr->model_params.use_mmap = params.use_mmap; d_ptr->model_params.use_mmap = params.use_mmap;
#if defined (__APPLE__) #if defined (__APPLE__)
d_ptr->model_params.use_mlock = true; d_ptr->model_params.use_mlock = true;
#else #else
d_ptr->model_params.use_mlock = params.use_mlock; d_ptr->model_params.use_mlock = params.use_mlock;
#endif #endif
d_ptr->ctx_params = llama_context_default_params();
d_ptr->ctx_params.n_ctx = 2048;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;
// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (llama_verbose()) { if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl; std::cerr << "llama.cpp: using Metal" << std::endl;
@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath)
return false; return false;
} }
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}
// -- initialize the context --
d_ptr->ctx_params = llama_context_default_params();
d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;
// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params); d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
if (!d_ptr->ctx) { if (!d_ptr->ctx) {
#ifdef GGML_USE_KOMPUTE #ifdef GGML_USE_KOMPUTE

View File

@ -17,9 +17,9 @@ public:
bool supportsEmbedding() const override { return false; } bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; } bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override; bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override; bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override; size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override; size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override; size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override; size_t restoreState(const uint8_t *src) override;

View File

@ -138,7 +138,7 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha
return nullptr; return nullptr;
} }
LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) { LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant, int n_ctx) {
if (!has_at_least_minimal_hardware()) { if (!has_at_least_minimal_hardware()) {
std::cerr << "LLModel ERROR: CPU does not support AVX\n"; std::cerr << "LLModel ERROR: CPU does not support AVX\n";
return nullptr; return nullptr;
@ -154,7 +154,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
if(impl) { if(impl) {
LLModel* metalimpl = impl->m_construct(); LLModel* metalimpl = impl->m_construct();
metalimpl->m_implementation = impl; metalimpl->m_implementation = impl;
size_t req_mem = metalimpl->requiredMem(modelPath); /* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at
* load time, not construct time. right now n_ctx is incorrectly hardcoded 2048 in
* most (all?) places where this is called, causing underestimation of required
* memory. */
size_t req_mem = metalimpl->requiredMem(modelPath, n_ctx);
float req_to_total = (float) req_mem / (float) total_mem; float req_to_total = (float) req_mem / (float) total_mem;
// on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not // on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not
if (req_to_total >= 0.53) { if (req_to_total >= 0.53) {
@ -165,6 +169,8 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
} }
} }
} }
#else
(void)n_ctx;
#endif #endif
if (!impl) { if (!impl) {

View File

@ -37,7 +37,7 @@ public:
static bool isImplementation(const Dlhandle&); static bool isImplementation(const Dlhandle&);
static const std::vector<Implementation>& implementationList(); static const std::vector<Implementation>& implementationList();
static const Implementation *implementation(const char *fname, const std::string& buildVariant); static const Implementation *implementation(const char *fname, const std::string& buildVariant);
static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto"); static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
static std::vector<GPUDevice> availableGPUDevices(); static std::vector<GPUDevice> availableGPUDevices();
static void setImplementationsSearchPath(const std::string& path); static void setImplementationsSearchPath(const std::string& path);
static const std::string& implementationsSearchPath(); static const std::string& implementationsSearchPath();
@ -74,9 +74,9 @@ public:
virtual bool supportsEmbedding() const = 0; virtual bool supportsEmbedding() const = 0;
virtual bool supportsCompletion() const = 0; virtual bool supportsCompletion() const = 0;
virtual bool loadModel(const std::string &modelPath) = 0; virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0;
virtual bool isModelLoaded() const = 0; virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath) = 0; virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0;
virtual size_t stateSize() const { return 0; } virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t */*dest*/) const { return 0; } virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
virtual size_t restoreState(const uint8_t */*src*/) { return 0; } virtual size_t restoreState(const uint8_t */*src*/) { return 0; }

View File

@ -47,16 +47,16 @@ void llmodel_model_destroy(llmodel_model model) {
delete reinterpret_cast<LLModelWrapper*>(model); delete reinterpret_cast<LLModelWrapper*>(model);
} }
size_t llmodel_required_mem(llmodel_model model, const char *model_path) size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx)
{ {
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model); LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->requiredMem(model_path); return wrapper->llModel->requiredMem(model_path, n_ctx);
} }
bool llmodel_loadModel(llmodel_model model, const char *model_path) bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx)
{ {
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model); LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->loadModel(model_path); return wrapper->llModel->loadModel(model_path, n_ctx);
} }
bool llmodel_isModelLoaded(llmodel_model model) bool llmodel_isModelLoaded(llmodel_model model)

View File

@ -110,17 +110,19 @@ void llmodel_model_destroy(llmodel_model model);
* Estimate RAM requirement for a model file * Estimate RAM requirement for a model file
* @param model A pointer to the llmodel_model instance. * @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file. * @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed. * @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
*/ */
size_t llmodel_required_mem(llmodel_model model, const char *model_path); size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx);
/** /**
* Load a model from a file. * Load a model from a file.
* @param model A pointer to the llmodel_model instance. * @param model A pointer to the llmodel_model instance.
* @param model_path A string representing the path to the model file. * @param model_path A string representing the path to the model file.
* @param n_ctx Maximum size of context window
* @return true if the model was loaded successfully, false otherwise. * @return true if the model was loaded successfully, false otherwise.
*/ */
bool llmodel_loadModel(llmodel_model model, const char *model_path); bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx);
/** /**
* Check if a model is loaded. * Check if a model is loaded.

View File

@ -188,7 +188,7 @@ public class LLModel : ILLModel
/// <returns>true if the model was loaded successfully, false otherwise.</returns> /// <returns>true if the model was loaded successfully, false otherwise.</returns>
public bool Load(string modelPath) public bool Load(string modelPath)
{ {
return NativeMethods.llmodel_loadModel(_handle, modelPath); return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048);
} }
protected void Destroy() protected void Destroy()

View File

@ -70,7 +70,8 @@ internal static unsafe partial class NativeMethods
[return: MarshalAs(UnmanagedType.I1)] [return: MarshalAs(UnmanagedType.I1)]
public static extern bool llmodel_loadModel( public static extern bool llmodel_loadModel(
[NativeTypeName("llmodel_model")] IntPtr model, [NativeTypeName("llmodel_model")] IntPtr model,
[NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path); [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
[NativeTypeName("int32_t")] int n_ctx);
[DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]

View File

@ -39,7 +39,7 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error); var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
_logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle); _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
_logger.LogInformation("Model loading started"); _logger.LogInformation("Model loading started");
var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath); var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
_logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully); _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
if (!loadedSuccessfully) if (!loadedSuccessfully)
{ {

View File

@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) {
fprintf(stderr, "%s: error '%s'\n", __func__, new_error); fprintf(stderr, "%s: error '%s'\n", __func__, new_error);
return nullptr; return nullptr;
} }
if (!llmodel_loadModel(model, fname)) { if (!llmodel_loadModel(model, fname, 2048)) {
llmodel_model_destroy(model); llmodel_model_destroy(model);
return nullptr; return nullptr;
} }

View File

@ -195,7 +195,7 @@ public class LLModel implements AutoCloseable {
if(model == null) { if(model == null) {
throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0)); throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0));
} }
library.llmodel_loadModel(model, modelPathAbs); library.llmodel_loadModel(model, modelPathAbs, 2048);
if(!library.llmodel_isModelLoaded(model)){ if(!library.llmodel_isModelLoaded(model)){
throw new IllegalStateException("The model " + modelName + " could not be loaded"); throw new IllegalStateException("The model " + modelName + " could not be loaded");

View File

@ -61,7 +61,7 @@ public interface LLModelLibrary {
Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error); Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error);
void llmodel_model_destroy(Pointer model); void llmodel_model_destroy(Pointer model);
boolean llmodel_loadModel(Pointer model, String model_path); boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx);
boolean llmodel_isModelLoaded(Pointer model); boolean llmodel_isModelLoaded(Pointer model);
@u_int64_t long llmodel_get_state_size(Pointer model); @u_int64_t long llmodel_get_state_size(Pointer model);
@u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest); @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);

View File

@ -1,2 +1,2 @@
from .gpt4all import Embed4All, GPT4All # noqa from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
from .pyllmodel import LLModel # noqa from .pyllmodel import LLModel as LLModel

View File

@ -69,6 +69,7 @@ class GPT4All:
allow_download: bool = True, allow_download: bool = True,
n_threads: Optional[int] = None, n_threads: Optional[int] = None,
device: Optional[str] = "cpu", device: Optional[str] = "cpu",
n_ctx: int = 2048,
verbose: bool = False, verbose: bool = False,
): ):
""" """
@ -90,15 +91,16 @@ class GPT4All:
Default is "cpu". Default is "cpu".
Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model. Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
n_ctx: Maximum size of context window
verbose: If True, print debug messages.
""" """
self.model_type = model_type self.model_type = model_type
self.model = pyllmodel.LLModel() self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed # Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose) self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
if device is not None: if device is not None and device != "cpu":
if device != "cpu": self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx)
self.model.init_gpu(model_path=self.config["path"], device=device) self.model.load_model(self.config["path"], n_ctx)
self.model.load_model(self.config["path"])
# Set n_threads # Set n_threads
if n_threads is not None: if n_threads is not None:
self.model.set_thread_count(n_threads) self.model.set_thread_count(n_threads)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import ctypes import ctypes
import importlib.resources import importlib.resources
import logging import logging
@ -7,6 +9,7 @@ import re
import subprocess import subprocess
import sys import sys
import threading import threading
from enum import Enum
from queue import Queue from queue import Queue
from typing import Callable, Iterable, List from typing import Callable, Iterable, List
@ -72,9 +75,9 @@ llmodel.llmodel_model_create2.restype = ctypes.c_void_p
llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p] llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_model_destroy.restype = None llmodel.llmodel_model_destroy.restype = None
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p] llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
llmodel.llmodel_loadModel.restype = ctypes.c_bool llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p] llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
llmodel.llmodel_required_mem.restype = ctypes.c_size_t llmodel.llmodel_required_mem.restype = ctypes.c_size_t
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p] llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
@ -114,7 +117,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p] llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32 llmodel.llmodel_threadCount.restype = ctypes.c_int32
llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode("utf-8")) llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())
llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)] llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice) llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
@ -143,10 +146,16 @@ def _create_model(model_path: bytes) -> ctypes.c_void_p:
err = ctypes.c_char_p() err = ctypes.c_char_p()
model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err)) model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err))
if model is None: if model is None:
raise ValueError(f"Unable to instantiate model: {err.decode()}") s = err.value
raise ValueError("Unable to instantiate model: {'null' if s is None else s.decode()}")
return model return model
# Symbol to terminate from generator
class Sentinel(Enum):
TERMINATING_SYMBOL = 0
class LLModel: class LLModel:
""" """
Base class and universal wrapper for GPT4All language models Base class and universal wrapper for GPT4All language models
@ -173,12 +182,16 @@ class LLModel:
if self.model is not None: if self.model is not None:
self.llmodel_lib.llmodel_model_destroy(self.model) self.llmodel_lib.llmodel_model_destroy(self.model)
def memory_needed(self, model_path: str) -> int: def memory_needed(self, model_path: str, n_ctx: int) -> int:
model_path_enc = model_path.encode("utf-8") self.model = None
self.model = _create_model(model_path_enc) return self._memory_needed(model_path, n_ctx)
return llmodel.llmodel_required_mem(self.model, model_path_enc)
def list_gpu(self, model_path: str) -> list: def _memory_needed(self, model_path: str, n_ctx: int) -> int:
if self.model is None:
self.model = _create_model(model_path.encode())
return llmodel.llmodel_required_mem(self.model, model_path.encode(), n_ctx)
def list_gpu(self, model_path: str, n_ctx: int) -> list[LLModelGPUDevice]:
""" """
Lists available GPU devices that satisfy the model's memory requirements. Lists available GPU devices that satisfy the model's memory requirements.
@ -186,45 +199,41 @@ class LLModel:
---------- ----------
model_path : str model_path : str
Path to the model. Path to the model.
n_ctx : int
Maximum size of context window
Returns Returns
------- -------
list list
A list of LLModelGPUDevice structures representing available GPU devices. A list of LLModelGPUDevice structures representing available GPU devices.
""" """
if self.model is not None: mem_required = self._memory_needed(model_path, n_ctx)
model_path_enc = model_path.encode("utf-8") return self._list_gpu(mem_required)
mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
else: def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
mem_required = self.memory_needed(model_path)
num_devices = ctypes.c_int32(0) num_devices = ctypes.c_int32(0)
devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices)) devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
if not devices_ptr: if not devices_ptr:
raise ValueError("Unable to retrieve available GPU devices") raise ValueError("Unable to retrieve available GPU devices")
devices = [devices_ptr[i] for i in range(num_devices.value)] return devices_ptr[:num_devices.value]
return devices
def init_gpu(self, model_path: str, device: str): def init_gpu(self, model_path: str, device: str, n_ctx: int):
if self.model is not None: mem_required = self._memory_needed(model_path, n_ctx)
model_path_enc = model_path.encode("utf-8")
mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc) success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode())
else:
mem_required = self.memory_needed(model_path)
device_enc = device.encode("utf-8")
success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device_enc)
if not success: if not success:
# Retrieve all GPUs without considering memory requirements. # Retrieve all GPUs without considering memory requirements.
num_devices = ctypes.c_int32(0) num_devices = ctypes.c_int32(0)
all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices)) all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices))
if not all_devices_ptr: if not all_devices_ptr:
raise ValueError("Unable to retrieve list of all GPU devices") raise ValueError("Unable to retrieve list of all GPU devices")
all_gpus = [all_devices_ptr[i].name.decode('utf-8') for i in range(num_devices.value)] all_gpus = [d.name.decode() for d in all_devices_ptr[:num_devices.value]]
# Retrieve GPUs that meet the memory requirements using list_gpu # Retrieve GPUs that meet the memory requirements using list_gpu
available_gpus = [device.name.decode('utf-8') for device in self.list_gpu(model_path)] available_gpus = [device.name.decode() for device in self._list_gpu(mem_required)]
# Identify GPUs that are unavailable due to insufficient memory or features # Identify GPUs that are unavailable due to insufficient memory or features
unavailable_gpus = set(all_gpus) - set(available_gpus) unavailable_gpus = set(all_gpus).difference(available_gpus)
# Formulate the error message # Formulate the error message
error_msg = "Unable to initialize model on GPU: '{}'.".format(device) error_msg = "Unable to initialize model on GPU: '{}'.".format(device)
@ -232,7 +241,7 @@ class LLModel:
error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus) error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus)
raise ValueError(error_msg) raise ValueError(error_msg)
def load_model(self, model_path: str) -> bool: def load_model(self, model_path: str, n_ctx: int) -> bool:
""" """
Load model from a file. Load model from a file.
@ -240,15 +249,16 @@ class LLModel:
---------- ----------
model_path : str model_path : str
Model filepath Model filepath
n_ctx : int
Maximum size of context window
Returns Returns
------- -------
True if model loaded successfully, False otherwise True if model loaded successfully, False otherwise
""" """
model_path_enc = model_path.encode("utf-8") self.model = _create_model(model_path.encode())
self.model = _create_model(model_path_enc)
llmodel.llmodel_loadModel(self.model, model_path_enc) llmodel.llmodel_loadModel(self.model, model_path.encode(), n_ctx)
filename = os.path.basename(model_path) filename = os.path.basename(model_path)
self.model_name = os.path.splitext(filename)[0] self.model_name = os.path.splitext(filename)[0]
@ -312,7 +322,7 @@ class LLModel:
raise ValueError("Text must not be None or empty") raise ValueError("Text must not be None or empty")
embedding_size = ctypes.c_size_t() embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8')) c_text = ctypes.c_char_p(text.encode())
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size)) embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)] embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
llmodel.llmodel_free_embedding(embedding_ptr) llmodel.llmodel_free_embedding(embedding_ptr)
@ -357,7 +367,7 @@ class LLModel:
prompt, prompt,
) )
prompt_bytes = prompt.encode("utf-8") prompt_bytes = prompt.encode()
prompt_ptr = ctypes.c_char_p(prompt_bytes) prompt_ptr = ctypes.c_char_p(prompt_bytes)
self._set_context( self._set_context(
@ -385,10 +395,7 @@ class LLModel:
def prompt_model_streaming( def prompt_model_streaming(
self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]: ) -> Iterable[str]:
# Symbol to terminate from generator output_queue: Queue[str | Sentinel] = Queue()
TERMINATING_SYMBOL = object()
output_queue: Queue = Queue()
# Put response tokens into an output queue # Put response tokens into an output queue
def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType: def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
@ -405,7 +412,7 @@ class LLModel:
def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs): def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs) self.prompt_model(prompt, callback, **kwargs)
output_queue.put(TERMINATING_SYMBOL) output_queue.put(Sentinel.TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator # Kick off llmodel_prompt in separate thread so we can return generator
# immediately # immediately
@ -419,7 +426,7 @@ class LLModel:
# Generator # Generator
while True: while True:
response = output_queue.get() response = output_queue.get()
if response is TERMINATING_SYMBOL: if isinstance(response, Sentinel):
break break
yield response yield response
@ -442,7 +449,7 @@ class LLModel:
else: else:
# beginning of a byte sequence # beginning of a byte sequence
if len(self.buffer) > 0: if len(self.buffer) > 0:
decoded.append(self.buffer.decode('utf-8', 'replace')) decoded.append(self.buffer.decode(errors='replace'))
self.buffer.clear() self.buffer.clear()
@ -451,7 +458,7 @@ class LLModel:
if self.buff_expecting_cont_bytes <= 0: if self.buff_expecting_cont_bytes <= 0:
# received the whole sequence or an out of place continuation byte # received the whole sequence or an out of place continuation byte
decoded.append(self.buffer.decode('utf-8', 'replace')) decoded.append(self.buffer.decode(errors='replace'))
self.buffer.clear() self.buffer.clear()
self.buff_expecting_cont_bytes = 0 self.buff_expecting_cont_bytes = 0

View File

@ -117,7 +117,7 @@ def test_empty_embedding():
def test_download_model(tmp_path: Path): def test_download_model(tmp_path: Path):
import gpt4all.gpt4all import gpt4all.gpt4all
old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path # temporary pytest directory to ensure a download happens gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path) # temporary pytest directory to ensure a download happens
try: try:
model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin') model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
model_path = tmp_path / model.config['filename'] model_path = tmp_path / model.config['filename']

View File

@ -28,7 +28,7 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
{ {
auto env = info.Env(); auto env = info.Env();
return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str()) )); return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048) ));
} }
Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info) Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
@ -161,7 +161,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
} }
} }
auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str()); auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048);
if(!success) { if(!success) {
Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException(); Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
return; return;

View File

@ -20,15 +20,17 @@ ChatGPT::ChatGPT()
{ {
} }
size_t ChatGPT::requiredMem(const std::string &modelPath) size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx)
{ {
Q_UNUSED(modelPath); Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return 0; return 0;
} }
bool ChatGPT::loadModel(const std::string &modelPath) bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx)
{ {
Q_UNUSED(modelPath); Q_UNUSED(modelPath);
Q_UNUSED(n_ctx);
return true; return true;
} }

View File

@ -48,9 +48,9 @@ public:
bool supportsEmbedding() const override { return false; } bool supportsEmbedding() const override { return false; }
bool supportsCompletion() const override { return true; } bool supportsCompletion() const override { return true; }
bool loadModel(const std::string &modelPath) override; bool loadModel(const std::string &modelPath, int n_ctx) override;
bool isModelLoaded() const override; bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath) override; size_t requiredMem(const std::string &modelPath, int n_ctx) override;
size_t stateSize() const override; size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override; size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override; size_t restoreState(const uint8_t *src) override;

View File

@ -5,7 +5,7 @@
#include <QDataStream> #include <QDataStream>
#define CHAT_FORMAT_MAGIC 0xF5D553CC #define CHAT_FORMAT_MAGIC 0xF5D553CC
#define CHAT_FORMAT_VERSION 6 #define CHAT_FORMAT_VERSION 7
class MyChatListModel: public ChatListModel { }; class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)

View File

@ -248,14 +248,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = model; m_llModelInfo.model = model;
} else { } else {
// TODO: make configurable in UI
auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
std::string buildVariant = "auto";
#if defined(Q_OS_MAC) && defined(__arm__) #if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal) if (m_forceMetal)
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "metal"); buildVariant = "metal";
else
m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
#else
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
#endif #endif
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
if (m_llModelInfo.model) { if (m_llModelInfo.model) {
// Update the settings that a model is being loaded and update the device list // Update the settings that a model is being loaded and update the device list
@ -267,7 +269,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (requestedDevice == "CPU") { if (requestedDevice == "CPU") {
emit reportFallbackReason(""); // fallback not applicable emit reportFallbackReason(""); // fallback not applicable
} else { } else {
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString()); const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx);
std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory); std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
LLModel::GPUDevice *device = nullptr; LLModel::GPUDevice *device = nullptr;
@ -296,14 +298,14 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
// Report which device we're actually using // Report which device we're actually using
emit reportDevice(actualDevice); emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString()); bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
if (actualDevice == "CPU") { if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU // we asked llama.cpp to use the CPU
} else if (!success) { } else if (!success) {
// llama_init_from_file returned nullptr // llama_init_from_file returned nullptr
emit reportDevice("CPU"); emit reportDevice("CPU");
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)"); emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
success = m_llModelInfo.model->loadModel(filePath.toStdString()); success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
} else if (!m_llModelInfo.model->usingGPUDevice()) { } else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp // ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate // We might have had to fallback to CPU after load if the model is not possible to accelerate
@ -763,6 +765,8 @@ bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
return false; return false;
} }
// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{ {
if (version > 1) { if (version > 1) {
@ -790,6 +794,9 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits; stream << responseLogits;
} }
stream << m_ctx.n_past; stream << m_ctx.n_past;
if (version >= 6) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.logits.size()); stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float)); stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size()); stream << quint64(m_ctx.tokens.size());
@ -839,6 +846,12 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
stream >> n_past; stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past; if (!discardKV) m_ctx.n_past = n_past;
if (version >= 6) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
}
quint64 logitsSize; quint64 logitsSize;
stream >> logitsSize; stream >> logitsSize;
if (!discardKV) { if (!discardKV) {

View File

@ -29,8 +29,8 @@ bool EmbeddingLLM::loadModel()
return false; return false;
} }
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto"); m_model = LLModel::Implementation::construct(filePath.toStdString());
bool success = m_model->loadModel(filePath.toStdString()); bool success = m_model->loadModel(filePath.toStdString(), 2048);
if (!success) { if (!success) {
qWarning() << "WARNING: Could not load sbert"; qWarning() << "WARNING: Could not load sbert";
delete m_model; delete m_model;

View File

@ -97,6 +97,17 @@ void ModelInfo::setPromptBatchSize(int s)
m_promptBatchSize = s; m_promptBatchSize = s;
} }
int ModelInfo::contextLength() const
{
return MySettings::globalInstance()->modelContextLength(*this);
}
void ModelInfo::setContextLength(int l)
{
if (isClone) MySettings::globalInstance()->setModelContextLength(*this, l, isClone /*force*/);
m_contextLength = l;
}
double ModelInfo::repeatPenalty() const double ModelInfo::repeatPenalty() const
{ {
return MySettings::globalInstance()->modelRepeatPenalty(*this); return MySettings::globalInstance()->modelRepeatPenalty(*this);
@ -274,6 +285,7 @@ ModelList::ModelList()
connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);; connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);;
connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings); connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
@ -525,6 +537,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->maxLength(); return info->maxLength();
case PromptBatchSizeRole: case PromptBatchSizeRole:
return info->promptBatchSize(); return info->promptBatchSize();
case ContextLengthRole:
return info->contextLength();
case RepeatPenaltyRole: case RepeatPenaltyRole:
return info->repeatPenalty(); return info->repeatPenalty();
case RepeatPenaltyTokensRole: case RepeatPenaltyTokensRole:
@ -740,6 +754,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::TopKRole, model.topK()); updateData(id, ModelList::TopKRole, model.topK());
updateData(id, ModelList::MaxLengthRole, model.maxLength()); updateData(id, ModelList::MaxLengthRole, model.maxLength());
updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize()); updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
updateData(id, ModelList::ContextLengthRole, model.contextLength());
updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty()); updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty());
updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens()); updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens());
updateData(id, ModelList::PromptTemplateRole, model.promptTemplate()); updateData(id, ModelList::PromptTemplateRole, model.promptTemplate());
@ -1106,6 +1121,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt()); updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt());
if (obj.contains("promptBatchSize")) if (obj.contains("promptBatchSize"))
updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt()); updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt());
if (obj.contains("contextLength"))
updateData(id, ModelList::ContextLengthRole, obj["contextLength"].toInt());
if (obj.contains("repeatPenalty")) if (obj.contains("repeatPenalty"))
updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble()); updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble());
if (obj.contains("repeatPenaltyTokens")) if (obj.contains("repeatPenaltyTokens"))
@ -1198,6 +1215,8 @@ void ModelList::updateModelsFromSettings()
const int maxLength = settings.value(g + "/maxLength").toInt(); const int maxLength = settings.value(g + "/maxLength").toInt();
Q_ASSERT(settings.contains(g + "/promptBatchSize")); Q_ASSERT(settings.contains(g + "/promptBatchSize"));
const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt(); const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt();
Q_ASSERT(settings.contains(g + "/contextLength"));
const int contextLength = settings.value(g + "/contextLength").toInt();
Q_ASSERT(settings.contains(g + "/repeatPenalty")); Q_ASSERT(settings.contains(g + "/repeatPenalty"));
const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble(); const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble();
Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens")); Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens"));
@ -1216,6 +1235,7 @@ void ModelList::updateModelsFromSettings()
updateData(id, ModelList::TopKRole, topK); updateData(id, ModelList::TopKRole, topK);
updateData(id, ModelList::MaxLengthRole, maxLength); updateData(id, ModelList::MaxLengthRole, maxLength);
updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize); updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
updateData(id, ModelList::ContextLengthRole, contextLength);
updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty); updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty);
updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens); updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens);
updateData(id, ModelList::PromptTemplateRole, promptTemplate); updateData(id, ModelList::PromptTemplateRole, promptTemplate);

View File

@ -39,6 +39,7 @@ struct ModelInfo {
Q_PROPERTY(int topK READ topK WRITE setTopK) Q_PROPERTY(int topK READ topK WRITE setTopK)
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength) Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize) Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
Q_PROPERTY(int contextLength READ contextLength WRITE setContextLength)
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty) Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens) Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate) Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
@ -94,6 +95,8 @@ public:
void setMaxLength(int l); void setMaxLength(int l);
int promptBatchSize() const; int promptBatchSize() const;
void setPromptBatchSize(int s); void setPromptBatchSize(int s);
int contextLength() const;
void setContextLength(int l);
double repeatPenalty() const; double repeatPenalty() const;
void setRepeatPenalty(double p); void setRepeatPenalty(double p);
int repeatPenaltyTokens() const; int repeatPenaltyTokens() const;
@ -112,6 +115,7 @@ private:
int m_topK = 40; int m_topK = 40;
int m_maxLength = 4096; int m_maxLength = 4096;
int m_promptBatchSize = 128; int m_promptBatchSize = 128;
int m_contextLength = 2048;
double m_repeatPenalty = 1.18; double m_repeatPenalty = 1.18;
int m_repeatPenaltyTokens = 64; int m_repeatPenaltyTokens = 64;
QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n"; QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n";
@ -227,6 +231,7 @@ public:
TopKRole, TopKRole,
MaxLengthRole, MaxLengthRole,
PromptBatchSizeRole, PromptBatchSizeRole,
ContextLengthRole,
RepeatPenaltyRole, RepeatPenaltyRole,
RepeatPenaltyTokensRole, RepeatPenaltyTokensRole,
PromptTemplateRole, PromptTemplateRole,
@ -269,6 +274,7 @@ public:
roles[TopKRole] = "topK"; roles[TopKRole] = "topK";
roles[MaxLengthRole] = "maxLength"; roles[MaxLengthRole] = "maxLength";
roles[PromptBatchSizeRole] = "promptBatchSize"; roles[PromptBatchSizeRole] = "promptBatchSize";
roles[ContextLengthRole] = "contextLength";
roles[RepeatPenaltyRole] = "repeatPenalty"; roles[RepeatPenaltyRole] = "repeatPenalty";
roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens"; roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
roles[PromptTemplateRole] = "promptTemplate"; roles[PromptTemplateRole] = "promptTemplate";

View File

@ -90,6 +90,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
setModelTopK(model, model.m_topK);; setModelTopK(model, model.m_topK);;
setModelMaxLength(model, model.m_maxLength); setModelMaxLength(model, model.m_maxLength);
setModelPromptBatchSize(model, model.m_promptBatchSize); setModelPromptBatchSize(model, model.m_promptBatchSize);
setModelContextLength(model, model.m_contextLength);
setModelRepeatPenalty(model, model.m_repeatPenalty); setModelRepeatPenalty(model, model.m_repeatPenalty);
setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens); setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens);
setModelPromptTemplate(model, model.m_promptTemplate); setModelPromptTemplate(model, model.m_promptTemplate);
@ -280,6 +281,28 @@ void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force)
emit promptBatchSizeChanged(m); emit promptBatchSizeChanged(m);
} }
int MySettings::modelContextLength(const ModelInfo &m) const
{
QSettings setting;
setting.sync();
return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt();
}
void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force)
{
if (modelContextLength(m) == l && !force)
return;
QSettings setting;
if (m.m_contextLength == l && !m.isClone)
setting.remove(QString("model-%1").arg(m.id()) + "/contextLength");
else
setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l);
setting.sync();
if (!force)
emit contextLengthChanged(m);
}
double MySettings::modelRepeatPenalty(const ModelInfo &m) const double MySettings::modelRepeatPenalty(const ModelInfo &m) const
{ {
QSettings setting; QSettings setting;

View File

@ -1,6 +1,8 @@
#ifndef MYSETTINGS_H #ifndef MYSETTINGS_H
#define MYSETTINGS_H #define MYSETTINGS_H
#include <cstdint>
#include <QObject> #include <QObject>
#include <QMutex> #include <QMutex>
@ -59,6 +61,8 @@ public:
Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false); Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false);
QString modelSystemPrompt(const ModelInfo &m) const; QString modelSystemPrompt(const ModelInfo &m) const;
Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false); Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false);
int modelContextLength(const ModelInfo &m) const;
Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false);
// Application settings // Application settings
int threadCount() const; int threadCount() const;
@ -79,6 +83,8 @@ public:
void setForceMetal(bool b); void setForceMetal(bool b);
QString device() const; QString device() const;
void setDevice(const QString &u); void setDevice(const QString &u);
int32_t contextLength() const;
void setContextLength(int32_t value);
// Release/Download settings // Release/Download settings
QString lastVersionStarted() const; QString lastVersionStarted() const;
@ -114,6 +120,7 @@ Q_SIGNALS:
void topKChanged(const ModelInfo &model); void topKChanged(const ModelInfo &model);
void maxLengthChanged(const ModelInfo &model); void maxLengthChanged(const ModelInfo &model);
void promptBatchSizeChanged(const ModelInfo &model); void promptBatchSizeChanged(const ModelInfo &model);
void contextLengthChanged(const ModelInfo &model);
void repeatPenaltyChanged(const ModelInfo &model); void repeatPenaltyChanged(const ModelInfo &model);
void repeatPenaltyTokensChanged(const ModelInfo &model); void repeatPenaltyTokensChanged(const ModelInfo &model);
void promptTemplateChanged(const ModelInfo &model); void promptTemplateChanged(const ModelInfo &model);

View File

@ -349,13 +349,61 @@ MySettingsTab {
rowSpacing: 10 rowSpacing: 10
columnSpacing: 10 columnSpacing: 10
Label {
id: contextLengthLabel
visible: !root.currentModelInfo.isChatGPT
text: qsTr("Context Length:")
font.pixelSize: theme.fontSizeLarge
color: theme.textColor
Layout.row: 0
Layout.column: 0
}
MyTextField {
id: contextLengthField
visible: !root.currentModelInfo.isChatGPT
text: root.currentModelInfo.contextLength
color: theme.textColor
font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum combined prompt/response tokens before information is lost.\nUsing more context than the model was trained on will yield poor results.\nNOTE: Does not take effect until you RESTART GPT4All or SWITCH MODELS.")
ToolTip.visible: hovered
Layout.row: 0
Layout.column: 1
validator: IntValidator {
bottom: 1
}
Connections {
target: MySettings
function onContextLengthChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
Connections {
target: root
function onCurrentModelInfoChanged() {
contextLengthField.text = root.currentModelInfo.contextLength;
}
}
onEditingFinished: {
var val = parseInt(text)
if (!isNaN(val)) {
MySettings.setModelContextLength(root.currentModelInfo, val)
focus = false
} else {
text = root.currentModelInfo.contextLength
}
}
Accessible.role: Accessible.EditableText
Accessible.name: contextLengthLabel.text
Accessible.description: ToolTip.text
}
Label { Label {
id: tempLabel id: tempLabel
text: qsTr("Temperature:") text: qsTr("Temperature:")
color: theme.textColor color: theme.textColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
Layout.row: 0 Layout.row: 1
Layout.column: 0 Layout.column: 2
} }
MyTextField { MyTextField {
@ -365,8 +413,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.") ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 0 Layout.row: 1
Layout.column: 1 Layout.column: 3
validator: DoubleValidator { validator: DoubleValidator {
locale: "C" locale: "C"
} }
@ -400,8 +448,8 @@ MySettingsTab {
text: qsTr("Top P:") text: qsTr("Top P:")
color: theme.textColor color: theme.textColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
Layout.row: 0 Layout.row: 2
Layout.column: 2 Layout.column: 0
} }
MyTextField { MyTextField {
id: topPField id: topPField
@ -410,8 +458,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling") ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 0 Layout.row: 2
Layout.column: 3 Layout.column: 1
validator: DoubleValidator { validator: DoubleValidator {
locale: "C" locale: "C"
} }
@ -446,8 +494,8 @@ MySettingsTab {
text: qsTr("Top K:") text: qsTr("Top K:")
color: theme.textColor color: theme.textColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
Layout.row: 1 Layout.row: 2
Layout.column: 0 Layout.column: 2
} }
MyTextField { MyTextField {
id: topKField id: topKField
@ -457,8 +505,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from") ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 1 Layout.row: 2
Layout.column: 1 Layout.column: 3
validator: IntValidator { validator: IntValidator {
bottom: 1 bottom: 1
} }
@ -493,7 +541,7 @@ MySettingsTab {
text: qsTr("Max Length:") text: qsTr("Max Length:")
color: theme.textColor color: theme.textColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
Layout.row: 1 Layout.row: 0
Layout.column: 2 Layout.column: 2
} }
MyTextField { MyTextField {
@ -504,7 +552,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Maximum length of response in tokens") ToolTip.text: qsTr("Maximum length of response in tokens")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 1 Layout.row: 0
Layout.column: 3 Layout.column: 3
validator: IntValidator { validator: IntValidator {
bottom: 1 bottom: 1
@ -541,7 +589,7 @@ MySettingsTab {
text: qsTr("Prompt Batch Size:") text: qsTr("Prompt Batch Size:")
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
color: theme.textColor color: theme.textColor
Layout.row: 2 Layout.row: 1
Layout.column: 0 Layout.column: 0
} }
MyTextField { MyTextField {
@ -552,7 +600,7 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM") ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 2 Layout.row: 1
Layout.column: 1 Layout.column: 1
validator: IntValidator { validator: IntValidator {
bottom: 1 bottom: 1
@ -588,8 +636,8 @@ MySettingsTab {
text: qsTr("Repeat Penalty:") text: qsTr("Repeat Penalty:")
color: theme.textColor color: theme.textColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
Layout.row: 2 Layout.row: 3
Layout.column: 2 Layout.column: 0
} }
MyTextField { MyTextField {
id: repeatPenaltyField id: repeatPenaltyField
@ -599,8 +647,8 @@ MySettingsTab {
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
ToolTip.text: qsTr("Amount to penalize repetitiveness of the output") ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 2 Layout.row: 3
Layout.column: 3 Layout.column: 1
validator: DoubleValidator { validator: DoubleValidator {
locale: "C" locale: "C"
} }
@ -636,7 +684,7 @@ MySettingsTab {
color: theme.textColor color: theme.textColor
font.pixelSize: theme.fontSizeLarge font.pixelSize: theme.fontSizeLarge
Layout.row: 3 Layout.row: 3
Layout.column: 0 Layout.column: 2
} }
MyTextField { MyTextField {
id: repeatPenaltyTokenField id: repeatPenaltyTokenField
@ -647,7 +695,7 @@ MySettingsTab {
ToolTip.text: qsTr("How far back in output to apply repeat penalty") ToolTip.text: qsTr("How far back in output to apply repeat penalty")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 3 Layout.row: 3
Layout.column: 1 Layout.column: 3
validator: IntValidator { validator: IntValidator {
bottom: 1 bottom: 1
} }